Skip to content

Commit

Permalink
Re-add ForwardDiffExt (including tests)
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Feb 18, 2025
1 parent 05f1bce commit 566257e
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 44 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMooncakeExt = ["Mooncake"]
Expand Down
72 changes: 29 additions & 43 deletions ext/DynamicPPLForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,54 +1,40 @@
module DynamicPPLForwardDiffExt

if isdefined(Base, :get_extension)
using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
using ForwardDiff
else
using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
using ..ForwardDiff
end

getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk

standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
standardtag(::ADTypes.AutoForwardDiff) = false

function LogDensityProblemsAD.ADgradient(
ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction
)
θ = DynamicPPL.getparams(ℓ)
f = Base.Fix1(LogDensityProblems.logdensity, ℓ)

# Define configuration for ForwardDiff.
tag = if standardtag(ad)
ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ))
using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems
using ForwardDiff

# check if the AD type already has a tag
use_dynamicppl_tag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
use_dynamicppl_tag(::ADTypes.AutoForwardDiff) = false

Check warning on line 8 in ext/DynamicPPLForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLForwardDiffExt.jl#L8

Added line #L8 was not covered by tests

function DynamicPPL.tweak_adtype(
ad::ADTypes.AutoForwardDiff{chunk_size},
::DynamicPPL.Model,
vi::DynamicPPL.AbstractVarInfo,
::DynamicPPL.AbstractContext,
) where {chunk_size}
params = vi[:]

# Use DynamicPPL tag to improve stack traces
# https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
# NOTE: DifferentiationInterface disables tag checking if the
# tag inside the AutoForwardDiff type is not nothing. See
# https://github.com/JuliaDiff/DifferentiationInterface.jl/blob/1df562180bdcc3e91c885aa5f4162a0be2ced850/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl#L338-L350.
# So we don't currently need to override ForwardDiff.checktag as well.
tag = if use_dynamicppl_tag(ad)
ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(params))
else
ForwardDiff.Tag(f, eltype(θ))
ad.tag

Check warning on line 27 in ext/DynamicPPLForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLForwardDiffExt.jl#L27

Added line #L27 was not covered by tests
end
chunk_size = getchunksize(ad)

# Optimise chunk size according to size of model
chunk = if chunk_size == 0 || chunk_size === nothing
ForwardDiff.Chunk(θ)
ForwardDiff.Chunk(params)
else
ForwardDiff.Chunk(length(θ), chunk_size)
ForwardDiff.Chunk(length(params), chunk_size)

Check warning on line 34 in ext/DynamicPPLForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLForwardDiffExt.jl#L34

Added line #L34 was not covered by tests
end

return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ)
end

# Allow Turing tag in gradient etc. calls of the log density function
function ForwardDiff.checktag(
::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
::DynamicPPL.LogDensityFunction,
::AbstractArray{W},
) where {V,W}
return true
end
function ForwardDiff.checktag(
::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction},
::AbstractArray{W},
) where {V,W}
return true
return ADTypes.AutoForwardDiff(; chunksize=ForwardDiff.chunksize(chunk), tag=tag)
end

end # module
24 changes: 23 additions & 1 deletion src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ struct LogDensityFunction{
if adtype === nothing
prep = nothing
else
# Check support
# Make backend-specific tweaks to the adtype
adtype = tweak_adtype(adtype, model, varinfo, context)
# Check whether it is supported
is_supported(adtype) ||
@warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
# Get a set of dummy params to use for prep
Expand Down Expand Up @@ -227,6 +229,26 @@ LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))

### Utils

"""
tweak_adtype(
adtype::ADTypes.AbstractADType,
model::Model,
varinfo::AbstractVarInfo,
context::AbstractContext
)
Return an 'optimised' form of the adtype. This is useful for doing
backend-specific optimisation of the adtype (e.g., for ForwardDiff, calculating
the chunk size: see the method override in `ext/DynamicPPLForwardDiffExt.jl`).
The model is passed as a parameter in case the optimisation depends on the
model.
By default, this just returns the input unchanged.
"""
tweak_adtype(
adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo, ::AbstractContext
) = adtype

"""
use_closure(adtype::ADTypes.AbstractADType)
Expand Down
32 changes: 32 additions & 0 deletions test/ext/DynamicPPLForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
module DynamicPPLForwardDiffExtTests

using DynamicPPL
using ADTypes: AutoForwardDiff
using ForwardDiff: ForwardDiff
using Distributions: MvNormal
using LinearAlgebra: I
using Test: @test, @testset

# get_chunksize(ad::AutoForwardDiff{chunk}) where {chunk} = chunk

@testset "ForwardDiff tweak_adtype" begin
MODEL_SIZE = 10
@model f() = x ~ MvNormal(zeros(MODEL_SIZE), I)
model = f()
varinfo = VarInfo(model)
context = DefaultContext()

@testset "Chunk size setting" for chunksize in (nothing, 0)
base_adtype = AutoForwardDiff(; chunksize=chunksize)
new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context)
@test new_adtype isa AutoForwardDiff{MODEL_SIZE}
end

@testset "Tag setting" begin
base_adtype = AutoForwardDiff()
new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context)
@test new_adtype.tag isa ForwardDiff.Tag{DynamicPPL.DynamicPPLTag}
end
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ include("test_util.jl")
include("ext/DynamicPPLJETExt.jl")
end
@testset "ad" begin
include("ext/DynamicPPLForwardDiffExt.jl")
include("ext/DynamicPPLMooncakeExt.jl")
include("ad.jl")
end
Expand Down

0 comments on commit 566257e

Please sign in to comment.