Add adtype to DynamicPPL.Model #818

Open · wants to merge 1 commit into base: release-0.35
19 changes: 18 additions & 1 deletion HISTORY.md
@@ -128,6 +128,22 @@ This release removes the feature of `VarInfo` where it kept track of which varia

**Other changes**

### Models now store AD backend types

In `DynamicPPL.Model`, an extra field `adtype::Union{Nothing,ADTypes.AbstractADType}` has been added. This field stores the AD backend that should be used when calculating gradients of the log density.

The field can be set by passing an extra argument to the `Model` constructor, but in practice users will more likely want to set the `adtype` on an existing model:

```julia
@model f() = ...
model = f()
model_with_adtype = setadtype(model, AutoForwardDiff())
```
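
This PR also adds a `Model(model, adtype)` convenience constructor, which attaches (or, with `nothing`, detaches) a backend on a copy of the model; a minimal sketch:

```julia
using DynamicPPL, Distributions, ADTypes

@model demo() = x ~ Normal()

# Attach a backend via the new copy constructor ...
model = Model(demo(), AutoForwardDiff())
model.adtype  # AutoForwardDiff()

# ... and detach it again by passing `nothing`.
model = Model(model, nothing)
model.adtype  # nothing
```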

On its own, this field has no effect on `DynamicPPL.Model`: model evaluation ignores it.
However, when a `LogDensityFunction` is constructed from such a model, it inherits the model's `adtype`.
See below for more information on `LogDensityFunction`.
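
For example (a sketch mirroring the tests added in this PR):

```julia
using DynamicPPL, Distributions, ADTypes, ForwardDiff, LogDensityProblems

@model demo() = x ~ Normal()
model = Model(demo(), AutoForwardDiff())

# The LogDensityFunction picks up the model's adtype by default ...
ldf = DynamicPPL.LogDensityFunction(model)

# ... so the gradient is available with no further setup.
LogDensityProblems.logdensity_and_gradient(ldf, [1.0])
```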

### `LogDensityProblems` interface

LogDensityProblemsAD is now removed as a dependency.
@@ -136,7 +152,8 @@ Instead of constructing a `LogDensityProblemsAD.ADgradient` object, we now direct
Note that if you wish, you can still construct an `ADgradient` out of a `LogDensityFunction` object (there is nothing preventing this).

However, in this version, `LogDensityFunction` takes an extra AD type argument.
By default, this AD type is inherited from the model that the `LogDensityFunction` is constructed from.
If the model does not have an AD type, or if the argument is explicitly set to `nothing`, the behaviour is exactly the same as before: you can calculate `logdensity`, but not its gradient.
If you do pass an AD type, however, the gradient can be calculated as well.
You may thus find that it is easier to pass the AD type directly when constructing the `LogDensityFunction`, as sketched below.

1 change: 1 addition & 0 deletions src/DynamicPPL.jl
@@ -84,6 +84,7 @@ export AbstractVarInfo,
getargnames,
extract_priors,
values_as_in_model,
setadtype,
# Samplers
Sampler,
SampleFromPrior,
2 changes: 1 addition & 1 deletion src/logdensityfunction.jl
@@ -111,7 +111,7 @@ struct LogDensityFunction{
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
context::AbstractContext=leafcontext(model.context);
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
adtype::Union{ADTypes.AbstractADType,Nothing}=model.adtype,
)
if adtype === nothing
prep = nothing
62 changes: 48 additions & 14 deletions src/model.jl
@@ -1,17 +1,23 @@
"""
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext}
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,TAD<:Union{Nothing,ADTypes.AbstractADType}}
f::F
args::NamedTuple{argnames,Targs}
defaults::NamedTuple{defaultnames,Tdefaults}
context::Ctx=DefaultContext()
adtype::TAD=nothing
end

A `Model` struct with model evaluation function of type `F`, arguments of names `argnames`
types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, missing
arguments `missings`, and evaluation context of type `Ctx`.
A `Model` struct contains the following fields:
- `f`, a model evaluation function of type `F`
- `args`, arguments of names `argnames` with types `Targs`
- `defaults`, default arguments of names `defaultnames` with types `Tdefaults`
- `context`, an evaluation context of type `Ctx`
- `adtype`, which can be nothing, or an automatic differentiation backend of type `TAD`

Its missing arguments are also stored as a type parameter `missings`.
Here `argnames`, `defaultnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`.
`context` is by default `DefaultContext()`.

`context` is by default `DefaultContext()`, and `adtype` is by default `nothing`.

An argument with a type of `Missing` will be in `missings` by default. However, in
non-traditional use-cases `missings` can be defined differently. All variables in `missings`
@@ -33,12 +39,21 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
```
"""
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <:
AbstractProbabilisticProgram
struct Model{
F,
argnames,
defaultnames,
missings,
Targs,
Tdefaults,
Ctx<:AbstractContext,
TAD<:Union{Nothing,ADTypes.AbstractADType},
} <: AbstractProbabilisticProgram
f::F
args::NamedTuple{argnames,Targs}
defaults::NamedTuple{defaultnames,Tdefaults}
context::Ctx
adtype::TAD

@doc """
Model{missings}(f, args::NamedTuple, defaults::NamedTuple)
@@ -51,9 +66,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte
args::NamedTuple{argnames,Targs},
defaults::NamedTuple{defaultnames,Tdefaults},
context::Ctx=DefaultContext(),
) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx}
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}(
f, args, defaults, context
adtype::TAD=nothing,
) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx,TAD}
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,TAD}(
f, args, defaults, context, adtype
)
end
end
@@ -71,22 +87,40 @@ model with different arguments.
args::NamedTuple{argnames,Targs},
defaults::NamedTuple{kwargnames,Tkwargs},
context::AbstractContext=DefaultContext(),
) where {F,argnames,Targs,kwargnames,Tkwargs}
adtype::TAD=nothing,
) where {F,argnames,Targs,kwargnames,Tkwargs,TAD}
missing_args = Tuple(
name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing
)
missing_kwargs = Tuple(
name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing
)
return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context))
return :(Model{$(missing_args..., missing_kwargs...)}(
f, args, defaults, context, adtype
))
end

function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...)
return Model(f, args, NamedTuple(kwargs), context)
return Model(f, args, NamedTuple(kwargs), context, nothing)
end

"""
Model(model::Model, adtype::Union{Nothing,ADTypes.AbstractADType})

Create a new model with the same evaluation function and arguments as `model`, but with
automatic differentiation backend `adtype`.
"""
function Model(model::Model, adtype::Union{Nothing,ADTypes.AbstractADType})
return Model(model.f, model.args, model.defaults, model.context, adtype)
end

"""
contextualize(model::Model, context::AbstractContext)

Set the context of `model` to `context`.
"""
function contextualize(model::Model, context::AbstractContext)
return Model(model.f, model.args, model.defaults, context)
return Model(model.f, model.args, model.defaults, context, model.adtype)
end

"""
16 changes: 16 additions & 0 deletions test/logdensityfunction.jl
@@ -9,6 +9,22 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff
end
end

@testset "AD type forwarding from model" begin
@model demo_simple() = x ~ Normal()
model = Model(demo_simple(), AutoForwardDiff())
ldf = DynamicPPL.LogDensityFunction(model)
# Check that the model's AD type is forwarded to the LDF
# Note: can't check ldf.adtype == AutoForwardDiff() because `tweak_adtype`
# modifies the underlying parameters a bit, so just check that it is still
# the correct backend package.
@test ldf.adtype isa AutoForwardDiff
# Check that the gradient can be evaluated on the resulting LDF
@test LogDensityProblems.capabilities(typeof(ldf)) ==
LogDensityProblems.LogDensityOrder{1}()
@test LogDensityProblems.logdensity(ldf, [1.0]) isa Any
@test LogDensityProblems.logdensity_and_gradient(ldf, [1.0]) isa Any
end

@testset "LogDensityFunction" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
10 changes: 10 additions & 0 deletions test/model.jl
@@ -100,6 +100,16 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
end
end

@testset "model adtype" begin
# Check that adtype can be set and unset
@model demo_adtype() = x ~ Normal()
adtype = AutoForwardDiff()
model = Model(demo_adtype(), adtype)
@test model.adtype == adtype
model = Model(model, nothing)
@test model.adtype === nothing
end

@testset "model de/conditioning" begin
@model function demo_condition()
x ~ Normal()