Skip to content

Commit

Permalink
Add ForwardDiffExt tests back
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Feb 18, 2025
1 parent 11d0e8a commit 8f8018a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
32 changes: 32 additions & 0 deletions test/ext/DynamicPPLForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
module DynamicPPLForwardDiffExtTests

using DynamicPPL
using ADTypes: AutoForwardDiff
using ForwardDiff: ForwardDiff
using Distributions: MvNormal
using LinearAlgebra: I
using Test: @test, @testset

# get_chunksize(ad::AutoForwardDiff{chunk}) where {chunk} = chunk

@testset "ForwardDiff tweak_adtype" begin
MODEL_SIZE = 10
@model f() = x ~ MvNormal(zeros(MODEL_SIZE), I)
model = f()
varinfo = VarInfo(model)
context = DefaultContext()

@testset "Chunk size setting" for chunksize in (nothing, 0)
base_adtype = AutoForwardDiff(; chunksize=chunksize)
new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context)
@test new_adtype isa AutoForwardDiff{MODEL_SIZE}
end

@testset "Tag setting" begin
base_adtype = AutoForwardDiff()
new_adtype = DynamicPPL.tweak_adtype(base_adtype, model, varinfo, context)
@test new_adtype.tag isa ForwardDiff.Tag{DynamicPPL.DynamicPPLTag}
end
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ include("test_util.jl")
include("ext/DynamicPPLJETExt.jl")
end
@testset "ad" begin
include("ext/DynamicPPLForwardDiffExt.jl")
include("ext/DynamicPPLMooncakeExt.jl")
include("ad.jl")
end
Expand Down

0 comments on commit 8f8018a

Please sign in to comment.