Add AD testing utilities #799
base: release-0.35
Codecov Report
Attention: Patch coverage is [...]

Additional details and impacted files

```diff
@@              Coverage Diff               @@
##           release-0.35    #799      +/-   ##
================================================
- Coverage         85.78%   3.96%   -81.82%
================================================
  Files                36      37       +1
  Lines              4207    4184      -23
================================================
- Hits               3609     166    -3443
- Misses              598    4018    +3420
```
Pull Request Test Coverage Report for Build 13218146414

Warning: This coverage report may be inaccurate. This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

💛 - Coveralls
Discussions with @willtebbutt on this: There are two perspectives for testing AD on models, which would have an impact on how we handle the present glut of VarInfo types:
Also, Will said he cares most about having the function [...]

Going forward, it makes sense that we have something that looks like this:

```julia
"""
Return an appropriate varinfo for the model. It would be nice if we could pass
the varinfo_type as a type itself, but I'm not sure if that's possible.

Also, unsure how to handle cases where a given varinfo type cannot be
constructed for a given model, e.g. SimpleVarInfo{NamedTuple} cannot handle
models with complex varnames.

This function could also take a vector of params and/or an rng seed to control
how the values in the varinfo are initialised.
"""
function construct_varinfo(varinfo_type::Symbol, model::Model) end

"""
All possible varinfo types.
"""
const ALL_VARINFO_TYPES = [:typed_vi, :untyped_vi, :vnv, :svi_nt, :svi_dict]  # ...

"""
All sensible varinfo types.
"""
const BASIC_VARINFO_TYPES = [:typed_vi, :vnv, :svi_nt]

"""
This is largely already implemented except that the second parameter requires
a varinfo object rather than a specification of its type.
"""
function make_function(model::Model, varinfo_type::Symbol) end

"""
Already implemented
"""
function make_params(model::Model) end

"""
Test a model with all specified varinfo_types and all AD backends in adtypes.

If reference_adtype is not passed, it just checks that AD runs without errors.
If it is passed, additionally use that to check for correctness.
"""
function test_model_ad(
    model::Model,
    adtypes::Vector{<:AbstractADType};
    varinfo_types::Vector{Symbol}=BASIC_VARINFO_TYPES,
    reference_adtype::Union{Nothing,AbstractADType}=nothing,
) end
```

On DPPL's side, we can then do this:

```julia
const TESTED_ADTYPES = [AutoForwardDiff()]  # ...

for model in DEMO_MODELS
    test_model_ad(model, TESTED_ADTYPES[2:end]; reference_adtype=TESTED_ADTYPES[1])
end
```

On the Mooncake side, Will can do this:

```julia
for model in DEMO_MODELS
    # adtypes is a vector, so a single backend is wrapped in one
    test_model_ad(model, [AutoMooncake(..)]; varinfo_types=ALL_VARINFO_TYPES)
end
```

Remaining question

DynamicPPL's [...] Thus, in principle, one might need to add an additional parameter [...]

We could unify these two scenarios by cutting out LogDensityProblemsAD ourselves 👀
I think before coming back to this, I'll first check whether dropping LogDensityProblemsAD results in anything bad. I think our general sense is that we should be able to cut it out without too many problems, and if that expectation is borne out by testing, then we should probably do it first.
Test failures are due to a bad rule so, happily, nothing systemic. I've opened a Mooncake issue to track, and will try to address in the morning so that we can get CI on this PR passing! compintell/Mooncake.jl#470
Overview
This is a, perhaps somewhat overdue, PR to add the functionality which I first wrote in https://github.com/penelopeysm/ModelTests.jl.
It provides two main functions:
- `DynamicPPL.TestUtils.AD.ad_ldp(::Model, ::Vector{<:Real}, ::AbstractADType, ::AbstractVarInfo)`
- `DynamicPPL.TestUtils.AD.ad_di` (same signature)

which calculate the logdensity of a given model, and its gradient, at the specified parameters.
The former uses LogDensityProblemsAD.jl; the latter circumvents this and goes straight to DifferentiationInterface.jl. (The varinfo argument is used only to specify the type of varinfo used during the evaluation; its contents are ignored. I wish that there was a cleaner way to specify this, but as far as I can tell it's not possible, especially with SimpleVarInfo, which often requires parameters to be initialised inside it.)
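As a rough illustration of what "logdensity and its gradient at the specified parameters" means when going straight to DifferentiationInterface, here is a minimal sketch. It does not use the functions from this PR: `logdensity` is a hand-written stand-in for a model (an iid standard normal, an assumption made for the example), and the only library call is DifferentiationInterface's `value_and_gradient(f, backend, x)`.

```julia
using DifferentiationInterface: value_and_gradient
using ADTypes: AutoForwardDiff
import ForwardDiff  # loads the ForwardDiff backend for DifferentiationInterface

# Stand-in for a model's log density (assumption): iid standard normal.
logdensity(x) = -0.5 * sum(abs2, x) - 0.5 * length(x) * log(2π)

# The point at which to evaluate the log density and its gradient.
x = [0.5, -1.0, 2.0]

# Returns the value and the gradient in one call.
value, grad = value_and_gradient(logdensity, AutoForwardDiff(), x)

# For this particular density the gradient is known analytically: it is -x.
```

Any function of a parameter vector that returns a scalar can be differentiated this way, which is why a generic log density function plus a parameter vector is enough to exercise an AD backend.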
There are three auxiliary functions:

- `DynamicPPL.TestUtils.AD.make_function` and `DynamicPPL.TestUtils.AD.make_params` generate a function `f` and an argument `x`, such that `f(x)` evaluates the logdensity of a model at the point `x`. These can, in theory, be passed to any autodiff library, even those which do not have integrations with LogDensityProblemsAD, DifferentiationInterface, or ADTypes.
- `DynamicPPL.TestUtils.AD.test_correctness` provides a quick and easy wrapper to test a model plus a given set of AD backends (using the default VarInfo) for correctness.

Testing

Unfortunately, I didn't manage to make much use of `test_correctness` in the current DynamicPPL test suite. The main reason is that we are testing all the demo models with pretty much all possible variations of VarInfo.

I have made sure not to change the tests, but I'm not entirely convinced that we need to test AD with different combinations of VarInfo. The reason is that AD is used primarily during sampling, and there isn't really any way to actually call `AbstractMCMC.sample` on a model (cf. #606) with anything but the default `VarInfo`.

The use of non-default varinfos is, as far as I can tell, restricted to fairly small sections of the codebase (e.g. the `loglikelihood`/`logjoint`/`logprior` functions), and it's not clear to me that AD is used in any part of that. So, it seems to me that these are orthogonal concerns.

I've left it versatile for now to be on the safe side, but if people agree then I would be very happy to remove the varinfo argument from the functions above.
Miscellaneous bits
The names of the functions can be changed. I'm not super happy with them, but I've also stared at this code for too long, so I'm not the best person to suggest names 😉