VarInfo with custom accumulators #744
Comments
Fleshing this idea out a bit. I'm imagining something like this:

```julia
abstract type AbstractAccumulator end

struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct NumProduce{T<:Integer} <: AbstractAccumulator
    num::T
end

struct Orders{T<:Integer} <: AbstractAccumulator
    orders::Dict{VarName,T}
end

struct VarInfo{Tmeta,Accs<:NTuple{N,AbstractAccumulator} where {N}} <: AbstractVarInfo
    metadata::Tmeta
    accs::Accs
end
```

Functions like
At the end of

```julia
vi = @set vi.accs = map(acc -> accumulate_observe!!(acc, vi, left, right), vi.accs)
# or
vi = @set vi.accs = map(acc -> accumulate_assume!!(acc, vi, vn, right), vi.accs)
```

which is only allowed to modify

```julia
accumulate_observe!!(acc::LogPrior, vi, left, right) = acc
accumulate_assume!!(acc::LogPrior, vi, vn, right) = LogPrior(acc.logp + logpdf(right, vi[vn]))
```

although we could provide the fallback defaults

```julia
accumulate_observe!!(acc::AbstractAccumulator, vi, left, right) = accumulate!!(acc, vi, left, right)
accumulate_assume!!(acc::AbstractAccumulator, vi, vn, right) = accumulate!!(acc, vi, vn, right)
accumulate!!(acc::AbstractAccumulator, vi, left, right) = acc
```

Separating
Benefits of a design like this:
I do wonder whether the
Ping @yebai and @penelopeysm for thoughts.
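To make the proposal above concrete, here is a minimal, dependency-free sketch of how per-accumulator dispatch could behave. It is an illustration under assumptions, not DynamicPPL code: `ToyVarInfo`, `normal_logpdf`, `assume`, and `observe` are hypothetical stand-ins invented for this example, and distributions are passed as plain `(mu, sigma)` tuples rather than Distributions.jl objects.

```julia
# Toy sketch of the accumulator idea. Everything except the accumulator names
# taken from the comment above is a hypothetical stand-in.
abstract type AbstractAccumulator end

struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct ToyVarInfo{Tmeta,Accs<:Tuple{Vararg{AbstractAccumulator}}}
    metadata::Tmeta   # here simply a Dict from variable name to value
    accs::Accs
end

# Log density of Normal(mu, sigma) evaluated at x.
normal_logpdf(mu, sigma, x) = -0.5 * ((x - mu) / sigma)^2 - log(sigma) - 0.5 * log(2pi)

# LogPrior reacts to assume statements and ignores observe statements.
accumulate_assume!!(acc::LogPrior, vi, vn, right) =
    LogPrior(acc.logp + normal_logpdf(right..., vi.metadata[vn]))
accumulate_observe!!(acc::LogPrior, vi, left, right) = acc

# LogLikelihood does the opposite.
accumulate_assume!!(acc::LogLikelihood, vi, vn, right) = acc
accumulate_observe!!(acc::LogLikelihood, vi, left, right) =
    LogLikelihood(acc.logp + normal_logpdf(right..., left))

# Rebuild the immutable container with the updated accumulators.
assume(vi, vn, right) =
    ToyVarInfo(vi.metadata, map(acc -> accumulate_assume!!(acc, vi, vn, right), vi.accs))
observe(vi, left, right) =
    ToyVarInfo(vi.metadata, map(acc -> accumulate_observe!!(acc, vi, left, right), vi.accs))

vi = ToyVarInfo(Dict(:x => 0.0), (LogPrior(0.0), LogLikelihood(0.0)))
vi = assume(vi, :x, (0.0, 1.0))     # x ~ Normal(0, 1), with x = 0 stored
vi = observe(vi, 1.0, (0.0, 1.0))   # 1.0 ~ Normal(0, 1)
```

After these two statements the prior and likelihood terms sit in separate accumulators, so neither can be mistaken for the other.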
Fully in favour of the idea, just a couple of incredibly minor technical details, which maybe shouldn't even be in this comment:
Possibly, but I think there is some entanglement of state anyway (exactly the instances that you described) so you'd end up with varinfo doing things to the accumulators in case it sees something. In which case we may as well put them together so that they have behaviour that's contained within one object
Maybe we could also define (but using a generated function for type stability):

```julia
function accumulate_observe!!(accs::NTuple{N,AbstractAccumulator}, vi, left, right) where {N}
    # Sketch only: the updated accumulators returned by each call still need to be collected.
    for acc in accs
        accumulate_observe!!(acc, vi, left, right)
    end
end
```

and then at the end of
I'd prefer not to, because it's not hard to implement the required methods, and I think it's easier to enforce the interface if it errors loudly when the interface isn't obeyed
Is it possible to do this also with a generated function?
Just to have the check be done at compile time and produce efficient code? Yeah, I think so. Could also try doing it the simpler way (maybe with recursion) and see if constant folding compiles it away, before bringing out the generated functions.
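As a rough illustration of the "simpler way (maybe with recursion)" mentioned above, the tuple of accumulators can be traversed by peeling off the first element and recursing on the rest. This is only a sketch: `Counter` and `Total` are hypothetical accumulators made up for the example, and whether the compiler fully unrolls the recursion for a given tuple would need checking with a tool such as `@code_warntype`.

```julia
abstract type AbstractAccumulator end

# Hypothetical accumulators, used only to exercise the tuple method.
struct Counter <: AbstractAccumulator
    n::Int
end
struct Total <: AbstractAccumulator
    s::Float64
end

accumulate_observe!!(acc::Counter, vi, left, right) = Counter(acc.n + 1)
accumulate_observe!!(acc::Total, vi, left, right) = Total(acc.s + left)

# Base case: nothing left to update.
accumulate_observe!!(::Tuple{}, vi, left, right) = ()

# Recursive case: update the first accumulator, then recurse on the tail.
function accumulate_observe!!(
    accs::Tuple{AbstractAccumulator,Vararg{AbstractAccumulator}}, vi, left, right
)
    head = accumulate_observe!!(first(accs), vi, left, right)
    return (head, accumulate_observe!!(Base.tail(accs), vi, left, right)...)
end

accs = (Counter(0), Total(0.0))
accs = accumulate_observe!!(accs, nothing, 2.5, nothing)
accs = accumulate_observe!!(accs, nothing, 1.5, nothing)
```

Unlike a `for` loop that discards each result, this version returns a new tuple of updated accumulators, which is what the `!!` convention requires. A plain `map` over the tuple would be another option worth benchmarking.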
Wouldn't it be easier, if I have an accumulator that only does something at observe, to only have to define that? Or if I have an accumulator that does the same thing for both observe and assume, then only have to define
I'm not sure what you mean by making it easier to enforce the interface, in that the default implementations would guarantee that the interface is followed if you just subtype
It's just one extra line to define a no-op
You can still do that by defining your own
I wasn't clear about this, apologies. You are right, having default implementations means that the interface will be satisfied even if the user doesn't try to. In that sense it actually makes it easier to satisfy the interface, in the sense that it is easier to have some function that returns something of the right type. However, having default methods makes it easier to inadvertently introduce bugs, because a user may forget to define the behaviour they wanted and a silent default method then does the wrong thing. In that sense, it makes it harder to have the right function that returns something of the right type and value. Basically, I want to force the user to think and be explicit about what they're doing. Also, having defaults makes it harder to find the correct method that a function call is dispatching to. See e.g. the current tilde pipeline for an extreme example of this. (I understand why we do it, but it doesn't make it any simpler)
I'm not super hung up about it though if it's well documented (with a docs page explaining how to implement a new accumulator)
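The trade-off discussed in the last few comments can be seen in a small sketch. The function names `strict_*` and `lenient_*` and the `ObserveCounter` type are hypothetical, chosen only to put the two styles side by side in one file; they are not proposed API.

```julia
abstract type AbstractAccumulator end

# Hypothetical accumulator whose author only thought about observe statements.
struct ObserveCounter <: AbstractAccumulator
    n::Int
end

# Style 1 (no defaults): only the methods the author wrote exist.
strict_observe!!(acc::ObserveCounter, vi, left, right) = ObserveCounter(acc.n + 1)
function strict_assume!! end   # generic function with no methods at all

# Style 2 (with defaults): a generic no-op fallback catches everything else.
lenient_observe!!(acc::AbstractAccumulator, vi, left, right) = acc
lenient_assume!!(acc::AbstractAccumulator, vi, vn, right) = acc
lenient_observe!!(acc::ObserveCounter, vi, left, right) = ObserveCounter(acc.n + 1)

acc = ObserveCounter(0)

# With defaults, the forgotten assume method silently does nothing.
silent = lenient_assume!!(acc, nothing, :x, nothing)

# Without defaults, the same omission surfaces as a MethodError.
loud = try
    strict_assume!!(acc, nothing, :x, nothing)
catch err
    err
end
```

Here a silent no-op happens to be the intended behaviour, which is the argument for defaults; the argument against is that the same fallback would hide the omission just as quietly if the author had meant to handle assume statements.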
A few high-level initial thoughts / questions:
The log-likelihood/log-prior computation would happen in the
The place to call

```julia
function tilde_assume!!(::EmptyContext, right, vn, vi)
    f = from_maybe_linked_internal_transform(vi, vn, right)
    r, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn))
    vi = @set vi.accs = map(acc -> accumulate_assume!!(acc, r, logjac, vi, vn, right), vi.accs)
    return r, vi
end
```

where
Is your worry about the multitude of functions for getting data from various accumulators, like the equivalent of current
That's a good point, I hadn't thought about it that way. I would add though that there's also a semantic difference for someone reading the code, i.e. I expect the two to be used in very different ways.
I hadn't thought about this necessarily changing anything about how contexts can be used. There would just be fewer contexts that we would need to implement.
I don't understand what you mean here by "wrapped".
@yebai and I have been discussing the idea of replacing the current `VarInfo` type with something more general as a wrapper around `VarNamedVector`/`Metadata`. The motivation is that `VarInfo` currently stores

- `logp`. This is ostensibly innocuous and straightforward, but it is sometimes the log prior, sometimes the log likelihood, and sometimes the log joint. Some samplers also hijack it, so that even when you're sampling from the log joint a sampler may actually use `logp` to store the log likelihood. This can cause mix-ups: I've had bugs in the new Gibbs implementation where I thought I had the likelihood but actually had the prior.
- `num_produce` and `order`, which are only used by particle methods (see "Decide the fate of `VarInfo.num_produce`" #661 for previous discussion).
- `VarNamedVector`/`Metadata` with the variable values.

We would rather not have fields in `VarInfo` that are only used by specific samplers, and others that are used for different purposes at different times.

The solution we've been thinking of would be some sort of wrapper type, probably nested, that wraps a `VarNamedVector` and allows one to store any extra information needed. If a particle sampler needs `num_produce` and `order`, it'll implement its own wrapper type, and other samplers that only need e.g. the logjoint would use a different wrapper type.

All of this could of course be stored in a context (because anything can be done with a context) but contexts are already too difficult to reason about, and should in my opinion only be used when a simpler tool won't cut it. Thus we would rather like an interface where somewhere in the tilde pipeline, maybe everywhere where we currently call `acclogp`, we would call some more generic `store_custom_varinfo_data` function, which each wrapper varinfo would then overload to store `logprior`/`num_produce`/`whatever_they_want_to_store`.

It's not obvious though whether such an interface would be powerful enough to implement all the things we want to use it for. So we should probably start by making a list of all the things we want to use it for. This would include at least
@yebai probably has more thoughts to share on this.
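A minimal sketch of the wrapper idea described in this issue might look as follows. Only `store_custom_varinfo_data` and `num_produce` come from the text above; the wrapper types, the `event` argument, and the choice to bump `num_produce` on observe statements are assumptions made for illustration, and a plain `Dict` stands in for `VarNamedVector`.

```julia
# Hypothetical wrapper types; not existing DynamicPPL API.
abstract type AbstractWrappedVarInfo end

# Wrapper for samplers that only need the log joint.
struct LogJointVarInfo{V} <: AbstractWrappedVarInfo
    values::V          # stands in for a VarNamedVector / Metadata
    logjoint::Float64
end

# Wrapper for a particle-style sampler that counts observe statements.
struct ParticleVarInfo{V} <: AbstractWrappedVarInfo
    values::V
    num_produce::Int
end

# `event` is :assume or :observe; `logp` is the log density of that statement.
store_custom_varinfo_data(vi::LogJointVarInfo, event::Symbol, logp) =
    LogJointVarInfo(vi.values, vi.logjoint + logp)

store_custom_varinfo_data(vi::ParticleVarInfo, event::Symbol, logp) =
    event === :observe ? ParticleVarInfo(vi.values, vi.num_produce + 1) : vi

vals = Dict(:x => 0.3)

joint_vi = LogJointVarInfo(vals, 0.0)
joint_vi = store_custom_varinfo_data(joint_vi, :assume, -1.2)
joint_vi = store_custom_varinfo_data(joint_vi, :observe, -0.7)

particle_vi = ParticleVarInfo(vals, 0)
particle_vi = store_custom_varinfo_data(particle_vi, :assume, -1.2)
particle_vi = store_custom_varinfo_data(particle_vi, :observe, -0.7)
```

Each sampler pays only for the bookkeeping it needs, and the tilde pipeline makes the same generic call regardless of which wrapper it was handed.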