Skip to content

Commit

Permalink
Add Enzyme to AD tests
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Feb 16, 2025
1 parent 7060896 commit cdbde76
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 49 deletions.
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Expand Down
12 changes: 8 additions & 4 deletions test/ad.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using DynamicPPL: LogDensityFunction
using Enzyme: Enzyme
using EnzymeCore: set_runtime_activity, Forward, Reverse

@testset "Automatic differentiation" begin
@testset verbose = true "Automatic differentiation" begin
@testset "Unsupported backends" begin
@model demo() = x ~ Normal()
@test_logs (:warn, r"not officially supported") LogDensityFunction(
Expand All @@ -23,9 +25,11 @@ using DynamicPPL: LogDensityFunction
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)

@testset "$adtype" for adtype in [
AutoReverseDiff(; compile=false),
AutoReverseDiff(; compile=true),
AutoMooncake(; config=nothing),
AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
# AutoReverseDiff(; compile=false),
# AutoReverseDiff(; compile=true),
# AutoMooncake(; config=nothing),
]
@info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype"

Expand Down
90 changes: 45 additions & 45 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,56 +45,56 @@ include("test_util.jl")
# groups are chosen to make both groups take roughly the same amount of
# time, but beyond that there is no particular reason for the split.
if GROUP == "All" || GROUP == "Group1"
include("utils.jl")
include("compiler.jl")
include("varnamedvector.jl")
include("varinfo.jl")
include("simple_varinfo.jl")
include("model.jl")
include("sampler.jl")
include("independence.jl")
include("distribution_wrappers.jl")
include("logdensityfunction.jl")
include("linking.jl")
include("serialization.jl")
include("pointwise_logdensities.jl")
include("lkj.jl")
include("deprecated.jl")
# include("utils.jl")
# include("compiler.jl")
# include("varnamedvector.jl")
# include("varinfo.jl")
# include("simple_varinfo.jl")
# include("model.jl")
# include("sampler.jl")
# include("independence.jl")
# include("distribution_wrappers.jl")
# include("logdensityfunction.jl")
# include("linking.jl")
# include("serialization.jl")
# include("pointwise_logdensities.jl")
# include("lkj.jl")
# include("deprecated.jl")
end

if GROUP == "All" || GROUP == "Group2"
include("contexts.jl")
include("context_implementations.jl")
include("threadsafe.jl")
include("debug_utils.jl")
@testset "compat" begin
include(joinpath("compat", "ad.jl"))
end
@testset "extensions" begin
include("ext/DynamicPPLMCMCChainsExt.jl")
include("ext/DynamicPPLJETExt.jl")
end
# include("contexts.jl")
# include("context_implementations.jl")
# include("threadsafe.jl")
# include("debug_utils.jl")
# @testset "compat" begin
# include(joinpath("compat", "ad.jl"))
# end
# @testset "extensions" begin
# include("ext/DynamicPPLMCMCChainsExt.jl")
# include("ext/DynamicPPLJETExt.jl")
# end
@testset "ad" begin
include("ext/DynamicPPLMooncakeExt.jl")
# include("ext/DynamicPPLMooncakeExt.jl")
include("ad.jl")
end
@testset "prob and logprob macro" begin
@test_throws ErrorException prob"..."
@test_throws ErrorException logprob"..."
end
@testset "doctests" begin
DocMeta.setdocmeta!(
DynamicPPL,
:DocTestSetup,
:(using DynamicPPL, Distributions);
recursive=true,
)
doctestfilters = [
# Ignore the source of a warning in the doctest output, since this is dependent on host.
# This is a line that starts with "└ @ " and ends with the line number.
r"└ @ .+:[0-9]+",
]
doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
end
# @testset "prob and logprob macro" begin
# @test_throws ErrorException prob"..."
# @test_throws ErrorException logprob"..."
# end
# @testset "doctests" begin
# DocMeta.setdocmeta!(
# DynamicPPL,
# :DocTestSetup,
# :(using DynamicPPL, Distributions);
# recursive=true,
# )
# doctestfilters = [
# # Ignore the source of a warning in the doctest output, since this is dependent on host.
# # This is a line that starts with "└ @ " and ends with the line number.
# r"└ @ .+:[0-9]+",
# ]
# doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
# end
end
end

0 comments on commit cdbde76

Please sign in to comment.