VarInfo with custom accumulators #744

Open
mhauru opened this issue Dec 10, 2024 · 7 comments

Comments

@mhauru
Member

mhauru commented Dec 10, 2024

@yebai and I have been discussing the idea of replacing the current VarInfo type with something more general as a wrapper around VarNamedVector/Metadata. The motivation is that VarInfo currently stores

  • logp. This looks innocuous and straightforward, but it is actually sometimes the log prior, sometimes the log likelihood, and sometimes the log joint. Also, some samplers hijack this field, so that even when you're sampling from the log joint, a sampler may actually use logp to store the log likelihood. This can cause mix-ups: I've had bugs in the new Gibbs implementation where I thought I had the likelihood but actually had the prior.
  • num_produce and order, which are only used by particle methods (see Decide the fate of VarInfo.num_produce #661 for previous discussion)
  • a VarNamedVector/Metadata with the variable values.

We would rather not have fields in VarInfo that are only used by specific samplers, and others that are used for different purposes at different times.

The solution we've been thinking of would be some sort of wrapper type, probably nested, that wraps a VarNamedVector and allows one to store any extra information needed. If a particle sampler needs num_produce and order, it'll implement its own wrapper type, and other samplers that only need e.g. the log joint would use a different wrapper type.

All of this could of course be stored in a context (because anything can be done with a context), but contexts are already too difficult to reason about, and in my opinion should only be used when a simpler tool won't cut it. Thus we would prefer an interface where somewhere in the tilde pipeline, maybe everywhere we currently call acclogp, we would call some more generic store_custom_varinfo_data function, which each wrapper varinfo would then overload to store logprior/num_produce/whatever_they_want_to_store.
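A minimal, flat (non-nested) sketch of such an interface, to make the idea concrete. Everything here is hypothetical: `ToyMetadata` stands in for VarNamedVector/Metadata, `LogJointInfo` and `ParticleInfo` are illustrative wrappers, and the `is_observe` flag plus the choice to bump `num_produce` only on observe statements are assumptions made for illustration.

```julia
# Stand-in for VarNamedVector/Metadata: variable values keyed by name.
struct ToyMetadata
    values::Dict{Symbol,Float64}
end

# Wrapper for a sampler that only needs the log joint.
struct LogJointInfo
    metadata::ToyMetadata
    logjoint::Float64
end

# Wrapper for a particle sampler that also tracks num_produce.
struct ParticleInfo
    metadata::ToyMetadata
    logjoint::Float64
    num_produce::Int
end

# Hypothetical hook, called wherever acclogp is called today. Each wrapper
# decides what to store and returns a new wrapper instead of mutating.
store_custom_varinfo_data(vi::LogJointInfo, logp, is_observe::Bool) =
    LogJointInfo(vi.metadata, vi.logjoint + logp)

# Int + Bool is an Int, so observe statements increment num_produce by one.
store_custom_varinfo_data(vi::ParticleInfo, logp, is_observe::Bool) =
    ParticleInfo(vi.metadata, vi.logjoint + logp, vi.num_produce + is_observe)

md = ToyMetadata(Dict(:x => 0.5))

pvi = ParticleInfo(md, 0.0, 0)
pvi = store_custom_varinfo_data(pvi, -1.25, false)  # an assume statement
pvi = store_custom_varinfo_data(pvi, -0.25, true)   # an observe statement

jvi = LogJointInfo(md, 0.0)
jvi = store_custom_varinfo_data(jvi, -1.25, false)
```

Each wrapper only carries the fields it declares, which is the point of moving sampler-specific state out of one shared type.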

It's not obvious though whether such an interface would be powerful enough to implement all the things we want to use it for. So we should probably start by making a list of all the things we want to use it for. This would include at least

  • log prior/log likelihood/log joint. I would like to store these separately too, to avoid mixing them up.
  • num_produce/order
  • what else?

@yebai probably has more thoughts to share on this.

@mhauru changed the title from "Nested VarInfo" to "VarInfo with custom accumulators" on Feb 20, 2025
@mhauru
Member Author

mhauru commented Feb 20, 2025

Fleshing this idea out a bit. I'm imagining something like this:

abstract type AbstractAccumulator end

struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct NumProduce{T<:Integer} <: AbstractAccumulator
    num::T
end

struct Orders{T<:Integer} <: AbstractAccumulator
    orders::Dict{VarName,T}
end

struct VarInfo{Tmeta,Accs<:NTuple{N,AbstractAccumulator} where {N}} <: AbstractVarInfo
    metadata::Tmeta
    accs::Accs
end

Functions like getlogp(vi), getloglikelihood(vi), and getnumproduce(vi) would check whether an accumulator of the necessary type is in vi.accs. If not, they would error with something like "This VarInfo is not tracking the num produce variable". If yes, they would return the relevant value. I'm not sure what to do if there are multiple accumulators of the same type; maybe just ban that possibility in an inner constructor.
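A minimal sketch of the lookup and of banning duplicates in an inner constructor. `ToyVarInfo` and `getacc` are hypothetical names, the metadata is an arbitrary placeholder, and duplicates are detected by type name so that e.g. `LogPrior{Float32}` and `LogPrior{Float64}` clash. The plain loop in `getacc` is simple but may not be type-stable for a heterogeneous tuple.

```julia
abstract type AbstractAccumulator end

struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct NumProduce{T<:Integer} <: AbstractAccumulator
    num::T
end

struct ToyVarInfo{Tmeta,Accs<:Tuple{Vararg{AbstractAccumulator}}}
    metadata::Tmeta
    accs::Accs
    function ToyVarInfo(metadata::Tmeta, accs::Accs) where {Tmeta,Accs<:Tuple{Vararg{AbstractAccumulator}}}
        # Compare accumulator types by name, ignoring type parameters.
        kinds = map(acc -> nameof(typeof(acc)), accs)
        allunique(kinds) || throw(ArgumentError("duplicate accumulator types: $kinds"))
        return new{Tmeta,Accs}(metadata, accs)
    end
end

# Return the first accumulator of type A, or fail loudly if there is none.
function getacc(vi::ToyVarInfo, ::Type{A}) where {A<:AbstractAccumulator}
    for acc in vi.accs
        acc isa A && return acc
    end
    throw(ArgumentError("This VarInfo is not tracking an accumulator of type $A"))
end

getlogprior(vi::ToyVarInfo) = getacc(vi, LogPrior).logp
getnumproduce(vi::ToyVarInfo) = getacc(vi, NumProduce).num

vi = ToyVarInfo(Dict(:x => 1.0), (LogPrior(-0.5), NumProduce(0)))
prior_only = ToyVarInfo(Dict(:x => 1.0), (LogPrior(-0.5),))
```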

At the end of tilde_observe!! and tilde_assume!! there would be something like

vi = @set vi.accs = map(acc -> accumulate_observe!!(acc, vi, left, right), vi.accs)
# or
vi = @set vi.accs = map(acc -> accumulate_assume!!(acc, vi, vn, right), vi.accs)

which is only allowed to modify acc, not the other arguments. The same way every context needs to define a method for tilde_observe and tilde_assume, every AbstractAccumulator would need to define methods for accumulate_observe!! and accumulate_assume!!, such as

accumulate_observe!!(acc::LogPrior, vi, left, right) = acc
accumulate_assume!!(acc::LogPrior, vi, vn, right) = LogPrior(acc.logp + logpdf(right, vi[vn]))

although we could provide the fallback defaults

accumulate_observe!!(acc::AbstractAccumulator, vi, left, right) = accumulate!!(acc, vi, left, right)
accumulate_assume!!(acc::AbstractAccumulator, vi, vn, right) = accumulate!!(acc, vi, vn, right)
accumulate!!(acc::AbstractAccumulator, vi, left, right) = acc
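A dependency-free sketch of the per-statement update together with the proposed fallbacks. Assumptions: `ToyNormal` and `toy_logpdf` stand in for `Distributions.Normal` and `logpdf`, the varinfo is reduced to a `Dict` of values, and `@set vi.accs = ...` is replaced by rebinding a plain tuple. `LogPrior` overrides only `accumulate_assume!!` and `LogLikelihood` only `accumulate_observe!!`; the fallbacks turn every other combination into a no-op.

```julia
abstract type AbstractAccumulator end

struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

# Stand-in for Distributions.Normal and logpdf, so the sketch needs no packages.
struct ToyNormal
    μ::Float64
    σ::Float64
end
toy_logpdf(d::ToyNormal, x) = -0.5 * ((x - d.μ) / d.σ)^2 - log(d.σ) - 0.5 * log(2π)

# Proposed fallbacks: both hooks forward to accumulate!!, which is a no-op by default.
accumulate_observe!!(acc::AbstractAccumulator, vi, left, right) = accumulate!!(acc, vi, left, right)
accumulate_assume!!(acc::AbstractAccumulator, vi, vn, right) = accumulate!!(acc, vi, vn, right)
accumulate!!(acc::AbstractAccumulator, vi, left, right) = acc

# Each accumulator overrides only the hook it cares about.
accumulate_assume!!(acc::LogPrior, vi, vn, right) = LogPrior(acc.logp + toy_logpdf(right, vi[vn]))
accumulate_observe!!(acc::LogLikelihood, vi, left, right) = LogLikelihood(acc.logp + toy_logpdf(right, left))

# One assume (μ ~ Normal(0, 1)) followed by one observe (1.0 ~ Normal(μ, 1)).
vals = Dict(:μ => 0.0)
accs = (LogPrior(0.0), LogLikelihood(0.0))
accs = map(acc -> accumulate_assume!!(acc, vals, :μ, ToyNormal(0.0, 1.0)), accs)
accs = map(acc -> accumulate_observe!!(acc, vals, 1.0, ToyNormal(vals[:μ], 1.0)), accs)
```

The same sketch works without the fallbacks if every accumulator defines both methods explicitly, which is the alternative argued for later in the thread.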

Separating logp into LogPrior and LogLikelihood is somewhat orthogonal to replacing vi.logp and vi.num_produce with vi.accs, but I don't see a reason not to accumulate them independently of each other and add them up in getlogp. setlogp!! and acclogp!! would no longer exist; you would always have to specify whether you are setting/adding to the log prior or the log likelihood.

Benefits of a design like this:

  • We could get rid of DefaultContext, LikelihoodContext, and PriorContext, and hence the whole notion of a leaf context, which would simplify context_implementations.jl a lot. We might also be able to get rid of e.g. PointwiseLogdensityContext, PriorExtractorContext, or ValuesAsInModelContext, and replace them with accumulators as well.
  • If some particular use case of a model needs to keep track of something unusual, such as particle methods needing Order and NumProduce, that can be implemented outside of DPPL. We no longer need to hard code that stuff into VarInfo. The implementations are much lighter and less error-prone than a full-fledged context would be.

I do wonder whether the accs::NTuple{N,AbstractAccumulator} should actually not be a field of VarInfo, but rather a separate object that gets passed around in the tilde pipeline next to the varinfo. The two serve quite different purposes, and you could argue that a lot of functions of VarInfo should in fact reset the accumulators to maintain a consistent state: For instance, if you change the value of a variable, or subset a VarInfo, the accumulator values no longer match the values stored in the Metadata.
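One way to make that consistency concern concrete, as a hedged sketch rather than a proposal: a hypothetical `setvalue!!` that changes a variable's value also resets every accumulator through a hypothetical `resetacc`, since anything accumulated from the old value would otherwise be stale. This `ToyVarInfo` keeps the accumulators as a field, i.e. the first of the two options above.

```julia
abstract type AbstractAccumulator end

struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct NumProduce{T<:Integer} <: AbstractAccumulator
    num::T
end

# Each accumulator knows its own "empty" state.
resetacc(acc::LogPrior) = LogPrior(zero(acc.logp))
resetacc(acc::NumProduce) = NumProduce(zero(acc.num))

struct ToyVarInfo{Accs<:Tuple}
    values::Dict{Symbol,Float64}
    accs::Accs
end

# Changing a value invalidates anything accumulated from the old values,
# so the returned varinfo carries freshly reset accumulators.
function setvalue!!(vi::ToyVarInfo, vn::Symbol, x)
    newvals = copy(vi.values)
    newvals[vn] = x
    return ToyVarInfo(newvals, map(resetacc, vi.accs))
end

vi = ToyVarInfo(Dict(:x => 0.0), (LogPrior(-0.92), NumProduce(3)))
vi2 = setvalue!!(vi, :x, 2.0)
```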

Ping @yebai and @penelopeysm for thoughts.

@penelopeysm
Member

Fully in favour of the idea, just a couple of incredibly minor technical details, which maybe shouldn't even be in this comment:

I do wonder whether the accs::NTuple{N,AbstractAccumulator} should actually not be a field of VarInfo, but rather a separate object that gets passed around in the tilde pipeline next to the varinfo.

Possibly, but I think there is some entanglement of state anyway (exactly the instances that you described), so you'd end up with varinfo doing things to the accumulators whenever it sees such a change. In that case we may as well put them together, so that their behaviour is contained within one object.

struct VarInfo{Tmeta,Accs<:NTuple{N,AbstractAccumulator} where {N}} <: AbstractVarInfo

Maybe we could also define (but using a generated function for type stability):

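function accumulate_observe!!(accs::NTuple{N,AbstractAccumulator}, vi, left, right) where {N}
    # Each call returns a (possibly new) accumulator, so collect the results
    # into a new tuple rather than discarding them.
    return map(acc -> accumulate_observe!!(acc, vi, left, right), accs)
end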

and then at the end of tilde_{assume,observe}!! we can just write

vi = @set vi.accs = accumulate_{assume,observe}!!(vi.accs, vi, left, right)
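A sketch of the generated-function variant, assuming the accumulators are held in a `Tuple`. Inside the generator the argument `accs` is bound to the tuple's type, so its length is known at compile time and one call per element can be emitted. `accumulate_observe_all!!` and `ObserveCount` are hypothetical names, and `right` is simplified to any callable returning a log density.

```julia
abstract type AbstractAccumulator end

struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

# Hypothetical accumulator that just counts observe statements.
struct ObserveCount <: AbstractAccumulator
    n::Int
end

# `right` is any callable returning a log density, standing in for a distribution.
accumulate_observe!!(acc::LogLikelihood, vi, left, right) = LogLikelihood(acc.logp + right(left))
accumulate_observe!!(acc::ObserveCount, vi, left, right) = ObserveCount(acc.n + 1)

# In the generator body `accs` is the tuple *type*, so fieldcount gives its
# length at compile time and we emit one call per element: an unrolled loop.
@generated function accumulate_observe_all!!(accs::Tuple, vi, left, right)
    calls = [:(accumulate_observe!!(accs[$i], vi, left, right)) for i in 1:fieldcount(accs)]
    return :(tuple($(calls...)))
end

toy_loglik(x) = -0.5 * x^2
accs = accumulate_observe_all!!((LogLikelihood(0.0), ObserveCount(0)), nothing, 2.0, toy_loglik)
```

Comparing its `@code_typed` output with the plain `map` version would show whether the generated function buys anything in practice.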

we could provide the fallback defaults

I'd prefer not to, because it's not hard to implement the required methods, and I think it's easier to enforce the interface if it errors loudly when the interface isn't obeyed.

Functions like getlogp(vi), getloglikelihood(vi), and getnumproduce(vi) would check whether an accumulator of the necessary type is in vi.accs

Is it possible to do this also with a generated function?

@mhauru
Member Author

mhauru commented Feb 21, 2025

Is it possible to do this also with a generated function?

Just to have the check be done at compile time and produce efficient code? Yeah, I think so. Could also try doing it the simpler way (maybe with recursion) and see if constant folding compiles it away, before bringing out the generated functions.
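A sketch of the simpler, recursive approach: the lookup dispatches on the element types of the accumulator tuple, peeling off one element per call with `Base.tail`. `findacc` is a hypothetical name. For a concrete tuple type the compiler can usually resolve the whole recursion statically, but that should be checked (e.g. with `@code_typed`) rather than assumed.

```julia
abstract type AbstractAccumulator end

struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct NumProduce{T<:Integer} <: AbstractAccumulator
    num::T
end

# Base case: the tuple is exhausted without finding a match.
findacc(::Type{A}, ::Tuple{}) where {A} =
    throw(ArgumentError("This VarInfo is not tracking an accumulator of type $A"))
# Match: the first element is an instance of A.
findacc(::Type{A}, accs::Tuple{A,Vararg}) where {A} = first(accs)
# No match yet: drop the first element and recurse on the rest.
findacc(::Type{A}, accs::Tuple) where {A} = findacc(A, Base.tail(accs))

accs = (NumProduce(0), LogPrior(-0.5), LogLikelihood(-1.5))
```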

I'd prefer not to, because it's not hard to implement the required methods, and I think it's easier to enforce the interface if it errors loudly when the interface isn't obeyed.

Wouldn't it be easier, if I have an accumulator that only does something at observe, to only have to define that? Or if I have an accumulator that does the same thing for both observe and assume, then only have to define accumulate!!?

I'm not sure what you mean by making it easier to enforce the interface. The default implementations would guarantee that the interface is followed if you just subtype AbstractAccumulator, unless you deliberately break it, e.g. by defining accumulate_observe!!(::MyType, args...) to return something it shouldn't. With the default implementations in place, the docs wouldn't say you have to implement accumulate_observe!! and accumulate_assume!!, but rather that you may implement one or both, or accumulate!!.

@penelopeysm
Member

penelopeysm commented Feb 21, 2025

Wouldn't it be easier, if I have an accumulator that only does something at observe, to only have to define that?

It's just one extra line to define a no-op accumulate_assume.

Or if I have an accumulator that does the same thing for both observe and assume, then only have to define accumulate!!

You can still do that by defining your own accumulate!! (or, even better, giving it a custom name like accumulate_numproduce!!) and redirecting both accumulate_observe!! and accumulate_assume!! to it, for a total of two extra lines.
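A sketch of that pattern with no fallbacks defined anywhere. `TildeCount` is a hypothetical accumulator that does the same thing on assume and observe, `accumulate_tildecount!!` is its custom-named shared function, and `ForgotAssume` shows the loud failure: calling the method its author forgot raises a `MethodError` instead of silently doing nothing.

```julia
abstract type AbstractAccumulator end

# Hypothetical accumulator that counts every tilde statement, assume or observe.
struct TildeCount <: AbstractAccumulator
    n::Int
end

# Hypothetical accumulator whose author forgot to define accumulate_assume!!.
struct ForgotAssume <: AbstractAccumulator end

# Note: no fallback methods on AbstractAccumulator exist in this sketch.

# Shared logic under its own descriptive name...
accumulate_tildecount!!(acc::TildeCount) = TildeCount(acc.n + 1)
# ...plus the two redirecting one-liners.
accumulate_assume!!(acc::TildeCount, vi, vn, right) = accumulate_tildecount!!(acc)
accumulate_observe!!(acc::TildeCount, vi, left, right) = accumulate_tildecount!!(acc)

accumulate_observe!!(acc::ForgotAssume, vi, left, right) = acc

acc = TildeCount(0)
acc = accumulate_assume!!(acc, nothing, :x, nothing)
acc = accumulate_observe!!(acc, nothing, 1.0, nothing)
```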

I'm not sure what you mean by making it easier to enforce the interface

I wasn't clear about this, apologies. You are right that having default implementations means the interface will be satisfied even if the user doesn't try to satisfy it. In that sense, defaults make it easier to have some function that returns something of the right type.

However, having default methods makes it easier to inadvertently introduce bugs: a user forgets to define the behaviour they wanted, and a silent default method now does the wrong thing. In that sense, defaults make it harder to have the right function that returns something of the right type and value.

Basically, I want to force the user to think and be explicit about what they're doing.

Also, having defaults makes it harder to find the method that a function call actually dispatches to. See e.g. the current tilde pipeline for an extreme example of this. (I understand why we do it, but it doesn't make it any simpler.)

@penelopeysm
Member

I'm not super hung up about it though if it's well documented (with a docs page explaining how to implement a new accumulator)

@torfjelde
Member

A few high-level initial thoughts / questions:

  1. Where in the tilde-pipeline would accumulate be called?
    1. Specifically, the purpose of PriorContext is that it short-circuits the tilde-pipeline to not even compute the log-likelihood. It's unclear to me how this would be done using the accumulator approach + if we could, it would lead to similar incompatibilities as we have now with varinfo, e.g. if I try to accumulate both Prior and Likelihood, both are short-circuiting the tilde-pipeline and nothing is done.
  2. Wrapping varinfo or related will be painful due to the number of functions you need to implement and test. Example: ThreadSafeVarInfo has bugged out so many times due to this.
  3. Only difference between varinfo and context right now is really just that:
    • varinfo is captured on return of the model, thus enabling making changes in an immutable way.
    • context is not captured, and so any changes must be made in-place.
  4. Does this change mean that contexts are not supposed to be mutated at all going forwards?
  5. I agree that the current contexts approach is quite general, and combining the contexts can be brittle. However, there's nothing stopping us from adding traits to the contexts (which has always been an ambition) that specify where in the context stack each should be, e.g. the GibbsContext should be a parent to a leaf. So I'm not sure this accumulator approach, which AFAIK can sometimes be wrapped and sometimes not, solves or improves this; such validation checks would still have to be implemented here as well. This also seems to be something you've considered in #2424 👍

@mhauru
Member Author

mhauru commented Feb 24, 2025

Where in the tilde-pipeline would accumulate be called?
Specifically, the purpose of PriorContext is that it short-circuits the tilde-pipeline to not even compute the log-likelihood. It's unclear to me how this would be done using the accumulator approach + if we could, it would lead to similar incompatibilities as we have now with varinfo, e.g. if I try to accumulate both Prior and Likelihood, both are short-circuiting the tilde-pipeline and nothing is done.

The log-likelihood/log-prior computation would happen in the accumulate_obssume!! function. ("Obssume" is a stand-in for either assume or observe.) Hence, if the varinfo e.g. carries a LogLikelihood accumulator but not a LogPrior accumulator, then in tilde_assume!! there would only be a call to accumulate_assume!!(::LogLikelihood, args...), which is a no-op, and the log prior would never be computed.
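A sketch of how the short-circuiting could fall out of dispatch, assuming the density term is computed inside the accumulator method. Signatures are simplified relative to the ones above, `toy_logpdf` stands in for `logpdf`, and `PRIOR_EVALS` is an illustrative counter. With only a `LogLikelihood` accumulator the prior density is never evaluated; with both accumulators both terms are computed, so the scenario raised above, where accumulating both would short-circuit everything, does not arise in this design.

```julia
abstract type AbstractAccumulator end

struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

# Illustrative counter of how many times a prior density is evaluated.
const PRIOR_EVALS = Ref(0)

# Stand-in for logpdf(Normal(0, 1), x), up to an additive constant.
toy_logpdf(x) = -0.5 * x^2

function accumulate_assume!!(acc::LogPrior, x, vn)
    PRIOR_EVALS[] += 1
    return LogPrior(acc.logp + toy_logpdf(x))
end
accumulate_assume!!(acc::LogLikelihood, x, vn) = acc  # prior terms are not its business

accumulate_observe!!(acc::LogPrior, left) = acc       # likelihood terms are not its business
accumulate_observe!!(acc::LogLikelihood, left) = LogLikelihood(acc.logp + toy_logpdf(left))

# A toy model with one assume (x) and one observe (y).
function run_toy_model(accs; x = 0.3, y = 1.0)
    accs = map(acc -> accumulate_assume!!(acc, x, :x), accs)
    accs = map(acc -> accumulate_observe!!(acc, y), accs)
    return accs
end

likelihood_only = run_toy_model((LogLikelihood(0.0),))
evals_after_likelihood_only = PRIOR_EVALS[]
both = run_toy_model((LogPrior(0.0), LogLikelihood(0.0)))
```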

The place to call accumulate_obssume!! would be after sampling, if any, has been done, and the varinfo metadata has an up-to-date value for all variables. Pretty much the same place where acclogp_obssume!! is currently called. To avoid doing invlinking multiple times, I should probably modify the above signature for accumulate_obssume!! to take in the current value of the variable and logabsdetjac, so that e.g. tilde_assume!! would be something like

function tilde_assume!!(::EmptyContext, right, vn, vi)
    f = from_maybe_linked_internal_transform(vi, vn, right)
    r, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn))
    vi = @set vi.accs = map(acc -> accumulate_assume!!(acc, r, logjac, vi, vn, right), vi.accs)
    return r, vi
end

where EmptyContext would be the only leaf context, the one that marks the end of the context stack. (Could call it DefaultContext, just not to be mixed with the current DefaultContext which should really be called JointContext.)
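A dependency-free sketch of that ordering: the transform runs once and every accumulator receives the same value and log-Jacobian, via the extended `accumulate_assume!!(acc, x, logjac, vi, vn, right)` signature proposed above. Assumptions: the variable is stored internally in log space, so the forward map is `exp` and log|J| equals the internal value; `exp_with_logjac` stands in for `with_logabsdet_jacobian` (from ChangesOfVariables.jl); `toy_logpdf` stands in for the log density of an Exponential(1) prior; and the hypothetical `LogJacobian` accumulator keeps the Jacobian term separate, which is only one possible choice.

```julia
abstract type AbstractAccumulator end

struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

# Hypothetical accumulator that keeps the Jacobian term separate.
struct LogJacobian{T<:AbstractFloat} <: AbstractAccumulator
    logjac::T
end

# Stand-in for with_logabsdet_jacobian(exp, y): d/dy exp(y) = exp(y), so log|J| = y.
exp_with_logjac(y) = (exp(y), y)

# Stand-in for logpdf(Exponential(1), x) on x > 0.
toy_logpdf(x) = -x

accumulate_assume!!(acc::LogPrior, x, logjac, vi, vn, right) = LogPrior(acc.logp + toy_logpdf(x))
accumulate_assume!!(acc::LogJacobian, x, logjac, vi, vn, right) = LogJacobian(acc.logjac + logjac)

# The transform runs exactly once; every accumulator sees the same x and logjac.
function toy_tilde_assume!!(internal::Dict{Symbol,Float64}, accs::Tuple, vn::Symbol)
    x, logjac = exp_with_logjac(internal[vn])
    accs = map(acc -> accumulate_assume!!(acc, x, logjac, internal, vn, nothing), accs)
    return x, accs
end

internal = Dict(:σ => log(2.0))  # σ = 2, stored as log(2)
x, accs = toy_tilde_assume!!(internal, (LogPrior(0.0), LogJacobian(0.0)), :σ)
```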

Wrapping varinfo or related will be painful due to the number of functions you need to implement and test. Example: ThreadSafeVarInfo has bugged out so many times due to this.

Is your worry about the multitude of functions for getting data from various accumulators, like the equivalents of the current getlogp and get_num_produce? I'm hoping VarInfo will really only have one function for "get stuff from the accumulators", call it getacc, for which one of the arguments is the type of the accumulator to get data from. Functions like getlogprior and getlogjoint would be sugared versions of getacc, where the sugar can be implemented at the level of AbstractVarInfo. Thus I don't see a reason to have many more functions to implement for e.g. ThreadSafeVarInfo; implementing getacc would be enough.
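A sketch of that layering with hypothetical names: `ToyAbstractVarInfo` plays the role of `AbstractVarInfo`, the sugar (`getlogprior`, `getloglikelihood`, `getlogjoint`) is defined once on the abstract type in terms of `getacc`, and `ToyWrapper`, standing in for something like ThreadSafeVarInfo, gets all of it by forwarding `getacc` alone.

```julia
abstract type AbstractAccumulator end

struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

abstract type ToyAbstractVarInfo end

# Sugar, defined once on the abstract type purely in terms of getacc.
getlogprior(vi::ToyAbstractVarInfo) = getacc(vi, LogPrior).logp
getloglikelihood(vi::ToyAbstractVarInfo) = getacc(vi, LogLikelihood).logp
getlogjoint(vi::ToyAbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi)

struct ToyVarInfo{Accs<:Tuple} <: ToyAbstractVarInfo
    accs::Accs
end

# The one primitive a concrete varinfo has to provide.
function getacc(vi::ToyVarInfo, ::Type{A}) where {A<:AbstractAccumulator}
    for acc in vi.accs
        acc isa A && return acc
    end
    throw(ArgumentError("This VarInfo is not tracking an accumulator of type $A"))
end

# A wrapper forwards getacc and inherits all the sugar for free.
struct ToyWrapper{V<:ToyAbstractVarInfo} <: ToyAbstractVarInfo
    inner::V
end
getacc(vi::ToyWrapper, ::Type{A}) where {A<:AbstractAccumulator} = getacc(vi.inner, A)

vi = ToyWrapper(ToyVarInfo((LogPrior(-1.0), LogLikelihood(-2.0))))
```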

Only difference between varinfo and context right now is really just that: [...]

That's a good point, I hadn't thought about it that way. I would add though that there's also a semantic difference for someone reading the code, i.e. I expect the two to be used in very different ways.

Does this change mean that contexts are not supposed to be mutated at all going forwards?

I hadn't thought about this necessarily changing anything about how contexts can be used. There would just be fewer contexts that we would need to implement.

So I'm not sure this accumulator approach, which AFAIK can sometimes be wrapped and sometimes not

I don't understand what you mean here by "wrapped".

Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants