Remove LogDensityProblemsAD; wrap adtype in LogDensityFunction (#806)
* Remove LogDensityProblemsAD

* Implement LogDensityFunctionWithGrad in place of ADgradient

* Dynamically decide whether to use closure vs constant

* Combine LogDensityFunction{,WithGrad} into one (#811)

* Warn if unsupported AD type is used

* Update changelog

* Update DI compat bound

Co-authored-by: Guillaume Dalle <[email protected]>

* Don't store with_closure inside LogDensityFunction

Co-authored-by: Guillaume Dalle <[email protected]>

* setadtype --> LogDensityFunction

* Re-add ForwardDiffExt (including tests)

* Add more tests for coverage

---------

Co-authored-by: Guillaume Dalle <[email protected]>
penelopeysm and gdalle authored Feb 19, 2025
1 parent f5e84f4 commit 90c7b26
Showing 13 changed files with 420 additions and 187 deletions.
48 changes: 47 additions & 1 deletion HISTORY.md
````diff
@@ -2,7 +2,7 @@
 
 ## 0.35.0
 
-**Breaking**
+**Breaking changes**
 
 ### `.~` right hand side must be a univariate distribution
 
@@ -119,6 +119,52 @@
 This change also affects sampling in Turing.jl.
 
+### `LogDensityFunction` argument order
+
+- The method `LogDensityFunction(varinfo, model, context)` has been removed.
+  The only accepted argument order is `LogDensityFunction(model, varinfo, context; adtype)`.
+  (For an explanation of `adtype`, see below.)
+  The varinfo and context arguments are both still optional.
+
+**Other changes**
+
+### `LogDensityProblems` interface
+
+LogDensityProblemsAD is now removed as a dependency.
+Instead of constructing a `LogDensityProblemsAD.ADgradient` object, we now directly use `DifferentiationInterface` to calculate the gradient of the log density with respect to model parameters.
+Note that, if you wish, you can still construct an `ADgradient` out of a `LogDensityFunction` object (there is nothing preventing this).
+
+In this version, however, `LogDensityFunction` takes an extra AD type argument.
+If this argument is not provided, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient.
+If you do pass an AD type, you can calculate the gradient as well.
+You may thus find it easier to do this instead:
+
+```julia
+@model f() = ...
+ldf = LogDensityFunction(f(); adtype=AutoForwardDiff())
+```
+
+This returns an object which satisfies the `LogDensityProblems` interface up to first order, i.e. you can now directly call both
+
+```julia
+LogDensityProblems.logdensity(ldf, params)
+LogDensityProblems.logdensity_and_gradient(ldf, params)
+```
+
+without having to construct a separate `ADgradient` object.
+
+If you prefer, you can also construct a new `LogDensityFunction` with a new AD type afterwards.
+The model, varinfo, and context will be taken from the original `LogDensityFunction`:
+
+```julia
+@model f() = ...
+ldf = LogDensityFunction(f()) # by default, no adtype set
+ldf_with_ad = LogDensityFunction(ldf, AutoForwardDiff())
+```
+
 ## 0.34.2
 
 - Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied.
````
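As a quick end-to-end illustration of the workflow the changelog describes (a sketch: the model, data, and parameter value are invented for this example):

```julia
using DynamicPPL, Distributions, LogDensityProblems
using ADTypes: AutoForwardDiff
import ForwardDiff  # the AD backend package must be loaded

# A one-parameter model, so the parameter vector has length 1.
@model function demo(x)
    mu ~ Normal()
    return x ~ Normal(mu, 1)
end

ldf = DynamicPPL.LogDensityFunction(demo(1.0); adtype=AutoForwardDiff())

# Zeroth- and first-order queries both work directly on `ldf`.
LogDensityProblems.logdensity(ldf, [0.3])
LogDensityProblems.logdensity_and_gradient(ldf, [0.3])
```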
7 changes: 3 additions & 4 deletions Project.toml
```diff
@@ -12,13 +12,13 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
-LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -51,15 +51,14 @@ Bijectors = "0.13.18, 0.14, 0.15"
 ChainRulesCore = "1"
 Compat = "4"
 ConstructionBase = "1.5.4"
+DifferentiationInterface = "0.6.41"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
-KernelAbstractions = "0.9.33"
 EnzymeCore = "0.6 - 0.8"
 ForwardDiff = "0.10"
 JET = "0.9"
 KernelAbstractions = "0.9.33"
 LinearAlgebra = "1.6"
 LogDensityProblems = "2"
-LogDensityProblemsAD = "1.7.0"
 MCMCChains = "6"
 MacroTools = "0.5.6"
 Mooncake = "0.4.95"
```
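For context, the `DifferentiationInterface` calling pattern that replaces `ADgradient` looks roughly like this. This is a simplified sketch with a stand-in log density, not DynamicPPL's actual internals:

```julia
using DifferentiationInterface: prepare_gradient, value_and_gradient
using ADTypes: AutoForwardDiff
import ForwardDiff  # the chosen backend package must be loaded

# Stand-in log density: an isotropic standard normal (up to a constant).
logdens(x) = -sum(abs2, x) / 2

backend = AutoForwardDiff()
prep = prepare_gradient(logdens, backend, zeros(3))           # one-off preparation
value_and_gradient(logdens, prep, backend, [1.0, 2.0, 3.0])   # (-7.0, [-1.0, -2.0, -3.0])
```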
2 changes: 1 addition & 1 deletion docs/src/api.md
````diff
@@ -54,7 +54,7 @@ logjoint
 
 ### LogDensityProblems.jl interface
 
-The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by simply wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`:
+The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`.
 
 ```@docs
 DynamicPPL.LogDensityFunction
````
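A sketch of what that wrapping provides (the one-variable model here is invented for illustration):

```julia
using DynamicPPL, Distributions, LogDensityProblems

@model demo() = x ~ Normal()

ldf = DynamicPPL.LogDensityFunction(demo())
LogDensityProblems.dimension(ldf)          # 1
LogDensityProblems.logdensity(ldf, [0.5])  # same as logpdf(Normal(), 0.5)
```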
72 changes: 29 additions & 43 deletions ext/DynamicPPLForwardDiffExt.jl
```diff
@@ -1,54 +1,40 @@
 module DynamicPPLForwardDiffExt
 
-if isdefined(Base, :get_extension)
-    using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
-    using ForwardDiff
-else
-    using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
-    using ..ForwardDiff
-end
-
-getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk
-
-standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
-standardtag(::ADTypes.AutoForwardDiff) = false
-
-function LogDensityProblemsAD.ADgradient(
-    ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction
-)
-    θ = DynamicPPL.getparams(ℓ)
-    f = Base.Fix1(LogDensityProblems.logdensity, ℓ)
-
-    # Define configuration for ForwardDiff.
-    tag = if standardtag(ad)
-        ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ))
+using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems
+using ForwardDiff
+
+# check if the AD type already has a tag
+use_dynamicppl_tag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
+use_dynamicppl_tag(::ADTypes.AutoForwardDiff) = false
+
+function DynamicPPL.tweak_adtype(
+    ad::ADTypes.AutoForwardDiff{chunk_size},
+    ::DynamicPPL.Model,
+    vi::DynamicPPL.AbstractVarInfo,
+    ::DynamicPPL.AbstractContext,
+) where {chunk_size}
+    params = vi[:]
+
+    # Use DynamicPPL tag to improve stack traces
+    # https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
+    # NOTE: DifferentiationInterface disables tag checking if the
+    # tag inside the AutoForwardDiff type is not nothing. See
+    # https://github.com/JuliaDiff/DifferentiationInterface.jl/blob/1df562180bdcc3e91c885aa5f4162a0be2ced850/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl#L338-L350.
+    # So we don't currently need to override ForwardDiff.checktag as well.
+    tag = if use_dynamicppl_tag(ad)
+        ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(params))
     else
-        ForwardDiff.Tag(f, eltype(θ))
+        ad.tag
     end
-    chunk_size = getchunksize(ad)
+
+    # Optimise chunk size according to size of model
     chunk = if chunk_size == 0 || chunk_size === nothing
-        ForwardDiff.Chunk(θ)
+        ForwardDiff.Chunk(params)
     else
-        ForwardDiff.Chunk(length(θ), chunk_size)
+        ForwardDiff.Chunk(length(params), chunk_size)
     end
 
-    return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ)
-end
-
-# Allow Turing tag in gradient etc. calls of the log density function
-function ForwardDiff.checktag(
-    ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
-    ::DynamicPPL.LogDensityFunction,
-    ::AbstractArray{W},
-) where {V,W}
-    return true
-end
-function ForwardDiff.checktag(
-    ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
-    ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction},
-    ::AbstractArray{W},
-) where {V,W}
-    return true
+    return ADTypes.AutoForwardDiff(; chunksize=ForwardDiff.chunksize(chunk), tag=tag)
 end
 
 end # module
```
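The chunk-size heuristic in `tweak_adtype` can be reproduced in isolation, assuming only the `ForwardDiff` and `ADTypes` API used in the diff above (the parameter vector is arbitrary):

```julia
using ADTypes: AutoForwardDiff
using ForwardDiff

params = randn(25)

# chunksize=0 (the AutoForwardDiff default) means "let ForwardDiff decide";
# the tweaked adtype pins that decision to a concrete value up front.
chunk = ForwardDiff.Chunk(params)
AutoForwardDiff(; chunksize=ForwardDiff.chunksize(chunk))
```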
1 change: 0 additions & 1 deletion src/DynamicPPL.jl
```diff
@@ -14,7 +14,6 @@ using MacroTools: MacroTools
 using ConstructionBase: ConstructionBase
 using Accessors: Accessors
 using LogDensityProblems: LogDensityProblems
-using LogDensityProblemsAD: LogDensityProblemsAD
 
 using LinearAlgebra: LinearAlgebra, Cholesky
 
```
1 change: 1 addition & 0 deletions src/contexts.jl
```diff
@@ -184,6 +184,7 @@ at which point it will return the sampler of that context.
 getsampler(context::SamplingContext) = context.sampler
 getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
 getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
+getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")
 
 """
 struct DefaultContext <: AbstractContext end
```
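The new `IsLeaf` method closes the trait recursion: contexts that wrap a sampler return it, and the search now terminates with a clear error at a leaf. A sketch, assuming the usual convenience constructor `SamplingContext()` and that these internal names are imported:

```julia
using DynamicPPL: getsampler, SamplingContext, DefaultContext

getsampler(SamplingContext())  # returns the wrapped sampler (SampleFromPrior by default)
getsampler(DefaultContext())   # errors: "No sampler found in context"
```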