Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AD testing utilities #799

Draft
wants to merge 5 commits into
base: release-0.35
Choose a base branch
from
Draft

Add AD testing utilities #799

wants to merge 5 commits into from

Conversation

penelopeysm
Copy link
Member

@penelopeysm penelopeysm commented Feb 5, 2025

Overview

This is a, perhaps somewhat overdue, PR to add the functionality which I first wrote in https://github.com/penelopeysm/ModelTests.jl.

It provides two main functions:

  • DynamicPPL.TestUtils.AD.ad_ldp(::Model, ::Vector{<:Real}, ::AbstractADType, ::AbstractVarInfo)
  • and DynamicPPL.TestUtils.AD.ad_di (same signature)

which calculate the logdensity and its gradient of a given model at the specified parameters.

The former uses LogDensityProblemsAD.jl; the latter circumvents this and goes straight to DifferentiationInterface.jl. (The varinfo argument is used only to specify the type of varinfo used during the evaluation, its contents are ignored. I wish that there was a cleaner way to specify this, but as far as I can tell it's not possible, especially with SimpleVarInfo which often requires parameters to be initialised inside it.)

There are three auxiliary functions:

  • DynamicPPL.TestUtils.AD.make_function and DynamicPPL.TestUtils.AD.make_params generate a function f and an argument x, such that f(x) evaluates the logdensity of a model at the point x. These can, in theory, be passed to any autodiff library, even those which do not have integrations with LogDensityProblemsAD, DifferentiationInterface, or ADTypes.
  • DynamicPPL.TestUtils.AD.test_correctness provides a quick and easy wrapper to test a model plus a given set of AD backends (using the default VarInfo) for correctness.

Testing

Unfortunately, I didn't manage to make much use of test_correctness in the current DynamicPPL test suite. The main reason is because we are testing all the demo models with pretty much all possible variations of VarInfo.

I have made sure to not change the tests, but I'm not entirely convinced that we need to test AD with different combinations of VarInfo. The reason is because AD is used primarily during sampling, and there isn't really any way to actually call AbstractMCMC.sample on a model (cf. #606) with anything but the default VarInfo.

The use of non-default varinfos is, as far as I can tell, restricted to fairly small sections of the codebase (e.g. the loglikelihood / logjoint / logprior functions), and it's not clear to me that AD is used in any part of that. So, it seems to me that these are orthogonal concerns.

I've left it versatile for now to be on the safe side, but if people agree then I would be very happy to remove the varinfo argument from the functions above.

Miscellaneous bits

The names of the functions can be changed, I'm not super happy with them, but also I've stared at this code for too long so I'm not the best person to suggest names 😉

@penelopeysm penelopeysm changed the base branch from master to release-0.35 February 5, 2025 16:02
Copy link

codecov bot commented Feb 5, 2025

Codecov Report

Attention: Patch coverage is 0% with 21 lines in your changes missing coverage. Please review.

Project coverage is 3.96%. Comparing base (1366440) to head (489c40e).
Report is 3 commits behind head on release-0.35.

Files with missing lines Patch % Lines
src/test_utils/ad.jl 0.00% 21 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (1366440) and HEAD (489c40e). Click for more details.

HEAD has 21 uploads less than BASE
Flag BASE (1366440) HEAD (489c40e)
28 7
Additional details and impacted files
@@               Coverage Diff                @@
##           release-0.35    #799       +/-   ##
================================================
- Coverage         85.78%   3.96%   -81.82%     
================================================
  Files                36      37        +1     
  Lines              4207    4184       -23     
================================================
- Hits               3609     166     -3443     
- Misses              598    4018     +3420     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@coveralls
Copy link

coveralls commented Feb 5, 2025

Pull Request Test Coverage Report for Build 13218146414

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 0 of 21 (0.0%) changed or added relevant lines in 1 file are covered.
  • 2732 unchanged lines in 26 files lost coverage.
  • Overall coverage decreased (-81.8%) to 4.062%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/test_utils/ad.jl 0 21 0.0%
Files with Coverage Reduction New Missed Lines %
src/selector.jl 2 0.0%
src/varname.jl 6 0.0%
src/test_utils/model_interface.jl 7 0.0%
src/model_utils.jl 11 0.0%
src/test_utils/contexts.jl 12 0.0%
src/distribution_wrappers.jl 13 0.0%
src/logdensityfunction.jl 21 0.0%
src/test_utils/varinfo.jl 23 0.0%
src/submodel_macro.jl 26 0.0%
src/extract_priors.jl 30 0.0%
Totals Coverage Status
Change from base Build 13156283797: -81.8%
Covered Lines: 166
Relevant Lines: 4087

💛 - Coveralls

@penelopeysm
Copy link
Member Author

penelopeysm commented Feb 7, 2025

Discussions with @willtebbutt on this:

There are two perspectives for testing AD on models, which would have an impact on how we handle the present glut of VarInfo types:

  1. The AD developer's point of view: "whether I can differentiate DPPL models". - From Will's perspective, he would like to test on as many things as possible because sometimes they help to catch cases that are generally applicable to many differentiation targets.
  2. The Turing developer's point of view: "whether my models can be differentiated". - As explained above, there is only really one type of VarInfo that is routinely used in Turing sampling, namely TypedVarInfo. It therefore feels overkill to test AD on every possibility of VarInfo, but it would make sense to test on (1) TypedVarInfo; (2) some flavour of VarNamedVector, because we intend to switch Metadata -> VarNamedVector; (3) possibly some flavour of SimpleVarInfo.

Also, Will said he cares most about having the function f to be differentiated + the parameters x to differentiate it at. This is handled by make_function and make_params

Going forward it makes sense that we have something that looks like this:

"""
Return an appropriate varinfo for the model. It would be nice if we could pass
the varinfo_type as a type itself, but I'm not sure if that's possible.

Also, unsure how to handle cases where a given varinfo type cannot be
constructed for a given model, e.g. SimpleVarInfo{NamedTuple} cannot handle
models with complex varnames.

This function could also take a vector of params and/or an rng seed to control
how the values in the varinfo are initialised.
"""
function construct_varinfo(varinfo_type::Symbol, model::Model) end

"""
All possible varinfo types.
"""
const ALL_VARINFO_TYPES = [:typed_vi, :untyped_vi, :vnv, :svi_nt, :svi_dict...]

"""
All sensible varinfo types.
"""
const BASIC_VARINFO_TYPES = [:typed_vi, :vnv, :svi_nt]

"""
This is largely already implemented except that the second parameter requires
a varinfo object rather than a specification of its type.
"""
function make_function(model::Model, varinfo_type::Symbol)

"""
Already implemented
"""
function make_params(model::Model)

"""
Test a model with all specified varinfo_types and all AD backends in adtypes.

If reference_adtype is not passed, it just checks that AD runs without errors.
If it is passed, additionally use that to also check for correctness.
"""
function test_model_ad(
    model::Model, 
    adtypes::Vector{<:AbstractADType};
    varinfo_types::Vector{Symbol}=BASIC_VARINFO_TYPES,
    reference_adtype::Union{Nothing,AbstractADType}=nothing,
)

On DPPL's side, we can then do this:

const TESTED_ADTYPES = [AutoForwardDiff, ...]

for model in DEMO_MODELS
    test_model_ad(model, TESTED_ADTYPES[2:end]; reference_adtype=TESTED_ADTYPES[1])
end

On the Mooncake side, Will can do this:

for model in DEMO_MODELS
    test_model_ad(model, AutoMooncake(..); varinfo_types=ALL_VARINFO_TYPES)
end

Remaining question

DynamicPPL's logdensity_and_gradient uses LogDensityProblemsAD. Thus, technically, testing whether logdensity is differentiable (which is what the above does) is not the same as testing whether logdensity_and_gradient runs correctly, even though the latter does (eventually) try to differentiate logdensity.

Thus, in principle, one might need to add an additional parameter use_ldpad::Bool to control which route is taken. From the DPPL side we would set this to true because we want to test the actual code being used in DPPL. From Will's side he would probably set it to false to avoid the extra indirection.

We could unify these two scenarios by cutting out LogDensityProblemsAD ourselves 👀

penelopeysm and others added 3 commits February 8, 2025 18:14
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@penelopeysm
Copy link
Member Author

I think before coming back to this, I'll first check whether dropping LogDensityProblemsAD results in anything bad. I think our general sense is that we should be able to cut it out without too many problems, and if that expectation is borne out by testing, then we should probably do it first.

@willtebbutt
Copy link
Member

Test failures are due to a bad rule so, happily, nothing systemic. I've opened a Mooncake issue to track, and will try to address in the morning so that we can get CI on this PR passing! compintell/Mooncake.jl#470

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants