Add support for ProductNamedTupleDistribution #801

Open
sunxd3 opened this issue Feb 5, 2025 · 1 comment
Labels
enhancement New feature or request

Comments

@sunxd3
Member

sunxd3 commented Feb 5, 2025

JuliaStats/Distributions.jl#1803 introduced NamedTupleVariate.

The same PR also added ProductNamedTupleDistribution (Distributions.ProductNamedTupleDistribution), a product distribution whose rand returns a NamedTuple. All of its components are independent of each other. It is one example of a distribution with NamedTupleVariate variate form.
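For reference, a minimal sketch of what this looks like in plain Distributions.jl, outside of any model. It assumes a Distributions.jl release that already includes the PR above and Julia 1.7 or later for the `(; x, y)` destructuring; the concrete parameter values are just for illustration.

```julia
using Distributions

# product_distribution on a NamedTuple of distributions builds a
# ProductNamedTupleDistribution.
d = product_distribution((x = Normal(), y = Dirichlet([2.0, 4.0])))
is_pntd = d isa Distributions.ProductNamedTupleDistribution

# rand returns a NamedTuple with the same field names as the components.
draw = rand(d)

# Ordinary (non-stochastic) unpacking by field name.
(; x, y) = draw

# The components are independent, so the joint log density should equal
# the sum of the component log densities.
lp_joint = logpdf(d, draw)
lp_sum = logpdf(Normal(), x) + logpdf(Dirichlet([2.0, 4.0]), y)
println(lp_joint ≈ lp_sum)
```

The stochastic unpacking proposed in this issue would roughly combine the `rand` call and the destructuring step into a single tilde statement.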

We might want to add support for ProductNamedTupleDistribution and other NamedTupleVariate distributions (ones that we or others define). This would involve changes to the syntax, essentially introducing a stochastic version of unpacking.

For instance,

@model function demo()
    x, y ~ product_distribution((x=Normal(), y=Dirichlet([2, 4])))
    return x, y
end

I think updating the @model macro to allow Expr(:tuple, ...) on the LHS shouldn't be too hard. However, it might cause some trouble with the tilde pipeline (@mhauru).
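A quick way to check what the macro would actually receive is to parse the statement and inspect it. This is only a sketch: it inspects the parsed expression rather than assuming its shape, and `d` is a placeholder symbol, not a real distribution.

```julia
# Parse the proposed statement without evaluating it.
ex = Meta.parse("x, y ~ d")

# Show the head and arguments of the parsed expression.
println(ex.head)
println(ex.args)

# True if the statement is a call to ~ whose left-hand side is Expr(:tuple, ...).
lhs_is_tuple = ex.head === :call &&
               ex.args[1] === :~ &&
               Meta.isexpr(ex.args[2], :tuple)
println(lhs_is_tuple)
```

If `lhs_is_tuple` is true, the macro only needs an extra branch for a tuple on the left-hand side of the tilde; no new surface syntax has to be parsed.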

We should also think of some corner cases to avoid complications similar to dot-tilde.

@penelopeysm
Member

penelopeysm commented Feb 17, 2025

It's the observe statements that are really difficult to sort out, in particular cases where part of the distribution is assumed and part of it is observed:

@model function f()
    1.5, y ~ product_distribution((x = Normal(), y = Normal()))
end

or

@model function f(x)
    x, y ~ product_distribution((x = Normal(), y = Normal()))
end

The reason is that we've always assumed that the lhs of a tilde is a single varname, and that varname is either observed (then we head down tilde_observe!!) or it isn't (then we head down tilde_assume!!).

One way of getting around this complexity is to do it properly and make it possible to condition separately on each variable in a collection of variables on the lhs 😄
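To make the "do it properly" option a bit more concrete, here is a toy sketch in plain Julia. It is not DynamicPPL code and does not touch tilde_assume!! or tilde_observe!!; `split_assume_observe` is a hypothetical helper invented for illustration, and it only works because the components are independent.

```julia
using Distributions

# Hypothetical helper (not part of Turing or DynamicPPL): handle each named
# component on its own, observing the ones present in `observed` and
# sampling the rest.
function split_assume_observe(dists::NamedTuple, observed::NamedTuple)
    obs_logp = 0.0
    vals = map(keys(dists)) do name
        dist = dists[name]
        if haskey(observed, name)
            value = observed[name]        # "observe" branch: the value is fixed
            obs_logp += logpdf(dist, value)
            value
        else
            rand(dist)                    # "assume" branch: draw a fresh value
        end
    end
    # Rebuild a NamedTuple with the same field names as the distributions.
    return NamedTuple{keys(dists)}(vals), obs_logp
end

# x is observed as 1.5, y is sampled.
vals, obs_logp = split_assume_observe((x = Normal(), y = Normal()), (x = 1.5,))
println(vals)
println(obs_logp)
```

Here `obs_logp` only accumulates the log density of the observed components; a real implementation would also need to track the log density of the assumed ones.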

Another way of getting round it would be to forbid syntax like that and make people do

@model function f()
    1.5 ~ Normal()
    y ~ Normal()
end
