parameters macro for specifying trace manually (aka Turing's VarInfo) #2492

Open
yebai opened this issue Feb 24, 2025 · 2 comments

@yebai
Member

yebai commented Feb 24, 2025

It has been known for a while that Turing models can be written as callable objects (functors), in the form of

using Turing

struct MyModel
    a::Int
end
@model function (f::MyModel)(x)
    m ~ Normal(f.a, 1)
    return x ~ Normal(m, 1)
end

Suppose one further assumes that all model parameters (i.e. the left-hand sides of tilde statements) must be fields of MyModel. In that case, the type definition of MyModel manually specifies the tracing data structure for the Turing model.
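
A package-free sketch of this idea, assuming the log joint is written out by hand rather than computed by Turing: here the struct holds both the hyperparameter a and the value of the random variable m (the extra field m is an assumption following the paragraph above), so the type definition alone fixes what a trace contains. normal_logpdf is a hypothetical helper, and calling the struct returns a log density rather than running a DynamicPPL model.

```julia
# Normal log-density written out by hand so the sketch needs no packages.
normal_logpdf(μ, σ, x) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)

# Hypothetical: every tilde-LHS variable of the model is a field of the struct,
# so the struct definition itself plays the role of the trace.
struct MyModel
    a::Int       # fixed hyperparameter
    m::Float64   # value of the random variable `m` in `m ~ Normal(a, 1)`
end

# Calling the struct evaluates the log joint of
#   m ~ Normal(a, 1)
#   x ~ Normal(m, 1)
# using the parameter value stored in the struct.
(f::MyModel)(x) = normal_logpdf(f.a, 1.0, f.m) + normal_logpdf(f.m, 1.0, x)

# Usage: log joint at m = 0.5 with observation x = 1.0.
lp = MyModel(0, 0.5)(1.0)
```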

However, the example above does not yet use MyModel as a substitute for TypedVarInfo. For that to happen, one could introduce a new @parameters macro that lets users specify a model's TypedVarInfo manually. Here is an example:

@parameters struct Demo{T}
   m::T
   s²::T
   x::Vector{T}
end

@model function (gdemo::Demo)()
   (;m, s², x) = gdemo
   m ~ Normal(0, sqrt(s²))
   s² ~ InverseGamma(2, 3)
   x .~ Normal(m, sqrt(s²))
end

Here, all the model parameters are passed to the model transparently via (;m, s², x) = gdemo. This ultimately avoids using VarInfo for tracing (we will still need VarInfo for other purposes in a lightweight form, e.g. for accumulating logp). If a sampler opts to leave a parameter unspecified (i.e. sample_from_prior / sample_from_proposal), the model evaluator can still sample that parameter from its conditional prior. Otherwise, samplers must give each parameter a concrete value or mark it as either missing or predict, so that the model evaluator handles it accordingly.

In addition, users could manually specify parameter constraints and bijector transforms on struct Demo (note that this happens without running the model even once). Together, these properties make the design clearer, reduce the complexity of DynamicPPL, and improve robustness by avoiding the edge cases of VarInfo.
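
A rough, package-free sketch of the two ideas in the paragraph above, under stated assumptions: the gdemo body is transcribed by hand into a plain function (no @model), every parameter is read from the struct by destructuring while logp is accumulated, and constraints are attached to the Demo type itself. The names supports, to_support, logjoint and logdensity_unconstrained are hypothetical, not DynamicPPL or Bijectors.jl API; the exp/log transform for the positive s² is one standard choice among several.

```julia
# Log-densities written out by hand so the sketch has no package dependencies.
normal_logpdf(μ, σ, x) = -0.5 * ((x - μ) / σ)^2 - log(σ) - 0.5 * log(2π)
# InverseGamma(α, β) with integer shape α, using Γ(α) = (α - 1)!.
invgamma_logpdf(α::Integer, β, x) =
    α * log(β) - log(factorial(α - 1)) - (α + 1) * log(x) - β / x

struct Demo{T}
    m::T
    s²::T
    x::Vector{T}
end

# Hypothetical trait: the support of each field, attached to the type,
# so it is known without running the model.
supports(::Type{<:Demo}) = (m = :real, s² = :positive, x = :real)

# Hypothetical transforms: unconstrained value -> (constrained value, log|Jacobian|).
to_support(::Val{:real}, η) = (η, 0.0)               # identity, log|J| = 0
to_support(::Val{:positive}, η::Real) = (exp(η), η)  # s² = exp(η), log|ds²/dη| = η

# Hand-transcribed model body: parameters come from the struct, logp is accumulated.
function logjoint(gdemo::Demo)
    (; m, s², x) = gdemo                                # property destructuring (Julia ≥ 1.7)
    lp = invgamma_logpdf(2, 3.0, s²)                    # s² ~ InverseGamma(2, 3)
    lp += normal_logpdf(0.0, sqrt(s²), m)               # m ~ Normal(0, sqrt(s²))
    lp += sum(xi -> normal_logpdf(m, sqrt(s²), xi), x)  # x .~ Normal(m, sqrt(s²))
    return lp
end

# Log density on unconstrained space: map each η into its field's support,
# build the struct, and add the log|Jacobian| terms.
function logdensity_unconstrained(η::NamedTuple)
    s = supports(Demo)
    m, jm = to_support(Val(s.m), η.m)
    s², js = to_support(Val(s.s²), η.s²)
    x, jx = to_support(Val(s.x), η.x)
    return logjoint(Demo(m, s², x)) + jm + js + jx
end
```

Because all values come from the struct, the order of the accumulation statements does not affect the result; the constraint information is available from the type alone, before any model evaluation.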

This approach offers many advantages, e.g.

A note on the practicalities for DynamicPPL: as I mentioned in a previous meeting, one could refactor the DynamicPPL functionality for building TypedVarInfo into an API that generates @parameters struct Demo{T} definitions for users automatically. This would involve running the model once, collecting all the model parameters and their support constraints, and then automatically generating a Julia struct definition from them.
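
The generation step could look roughly like the sketch below, assuming the single model run has already produced a NamedTuple of parameter values (hand-written here as trace). parameters_struct_expr is a hypothetical helper that builds the struct definition as a Julia Expr and evaluates it. Unlike Demo{T}, it fixes concrete field types from the observed values instead of introducing a type parameter, and it does not record support constraints.

```julia
# Hypothetical helper: turn collected (name => value) pairs into a struct definition.
function parameters_struct_expr(name::Symbol, trace::NamedTuple)
    # One `field::Type` line per collected parameter, typed by its observed value.
    fields = [Expr(:(::), k, typeof(v)) for (k, v) in pairs(trace)]
    # Equivalent to `struct name ... end` (immutable, hence `false`).
    return Expr(:struct, false, name, Expr(:block, fields...))
end

# Stand-in for the values collected by running the model once.
trace = (m = 0.3, s² = 1.7, x = [1.5, 2.0])

ex = parameters_struct_expr(:DemoParams, trace)
eval(ex)  # defines DemoParams with fields m::Float64, s²::Float64, x::Vector{Float64}

# The collected values can then populate the generated type.
θ = DemoParams(trace...)
```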

This new design (passing the tracing data structure explicitly and assuming it is immutable) will resolve the following:

@mhauru
Member

mhauru commented Feb 24, 2025

As discussed previously, I like this proposal as long as we can find a way to make most @models generate their @parameters automatically, so that (a) simple models continue to work as they do now, and (b) we don't need to maintain the new struct-based VarInfo in parallel with the current one.

I wonder what the best syntax is for connecting a particular parameter struct with the variables of a model. With

@model function (gdemo::Demo)()
   (;m, s², x) = gdemo
   m ~ Normal(0, sqrt(s²))
   s² ~ InverseGamma(2, 3)
   x .~ Normal(m, sqrt(s²))
end

it's not obvious to me how we would match, e.g., the m in m ~ Normal(0, sqrt(s²)) to the field gdemo.m. We could make gdemo.m return some type of our own for this purpose, but then we would lose the nice property that, during the execution of a Turing model, random variables are just regular Julia variables with regular values like Float64. Relying on matching the names of fields and variables also feels a bit hacky and magical.

@model function (gdemo::Demo)()
   gdemo.m ~ Normal(0, sqrt(gdemo.s²))
   gdemo.s² ~ InverseGamma(2, 3)
   gdemo.x .~ Normal(gdemo.m, sqrt(gdemo.s²))
end

would work but gets a bit verbose.

@yebai
Member Author

yebai commented Feb 25, 2025

> We could make gdemo.m return some type of our own for this purpose

A quick note for future discussion: we could use a type similar to ForwardDiff's Dual to carry extra sampler/inference information. This can be made very generic and robust. For probabilistic programming, implementing such a Dual-like type is often more straightforward than it is for differentiable programming.
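
A rough sketch of the idea, assuming the extra information is simply the set of parameter names a value depends on: Tagged is a hypothetical type (not ForwardDiff's API) that forwards arithmetic to the wrapped value and merges dependency sets, much as Dual propagates partials. It deliberately does not subtype Real, to sidestep method ambiguities, so unlike Dual it would not yet pass through code that requires a Real, such as distribution constructors.

```julia
# Hypothetical Dual-like wrapper: a numeric value plus inference metadata,
# here the names of the model parameters the value depends on.
struct Tagged{T<:Real}
    value::T
    deps::Vector{Symbol}
end

# Convenience constructor for a value that is itself the parameter `name`.
Tagged(x::Real, name::Symbol) = Tagged(x, [name])

# Unary op: compute on the value, keep the dependencies.
Base.sqrt(x::Tagged) = Tagged(sqrt(x.value), copy(x.deps))

# Binary ops: compute on the values, take the union of the dependencies.
for op in (:+, :-, :*, :/)
    @eval begin
        Base.$op(x::Tagged, y::Tagged) = Tagged($op(x.value, y.value), union(x.deps, y.deps))
        Base.$op(x::Tagged, y::Real) = Tagged($op(x.value, y), copy(x.deps))
        Base.$op(x::Real, y::Tagged) = Tagged($op(x, y.value), copy(y.deps))
    end
end

# Usage: the location of x .~ Normal(m, sqrt(s²)) "knows" it depends on m and s².
m = Tagged(0.5, :m)
s² = Tagged(4.0, :s²)
μ = m + sqrt(s²)  # μ.value == 2.5, μ.deps == [:m, :s²]
```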
