Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] submodel explorations #815

Draft
wants to merge 17 commits into
base: release-0.35
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 69 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,74 @@

**Breaking**

### `.~` right hand side must be a univariate distribution

Previously we allowed statements like

```julia
x .~ [Normal(), Gamma()]
```

where the right hand side of a `.~` was an array of distributions, and ones like

```julia
x .~ MvNormal(fill(0.0, 2), I)
```

where the right hand side was a multivariate distribution.

These are no longer allowed. The only things allowed on the right hand side of a `.~` statement are univariate distributions, such as

```julia
x = Array{Float64,3}(undef, 2, 3, 4)
x .~ Normal()
```

The reasons for this are internal code simplification and the fact that broadcasting where both sides are multidimensional but of different dimensions is typically confusing to read.

If the right hand side and the left hand side have the same dimension, one can simply use `~`. Arrays of distributions can be replaced with `product_distribution`. So instead of

```julia
x .~ [Normal(), Gamma()]
x .~ Normal.(y)
x .~ MvNormal(fill(0.0, 2), I)
```

do

```julia
x ~ product_distribution([Normal(), Gamma()])
x ~ product_distribution(Normal.(y))
x ~ MvNormal(fill(0.0, 2), I)
```

This is often more performant as well. Note that using `~` rather than `.~` does change the internal storage format a bit: With `.~` `x[i]` are stored as separate variables, with `~` as a single multivariate variable `x`. In most cases this does not change anything for the user, but if it does cause issues, e.g. if you are dealing with `VarInfo` objects directly and need to keep the old behavior, you can always expand into a loop, such as

```julia
dists = Normal.(y)
for i in 1:length(dists)
x[i] ~ dists[i]
end
```

Cases where the right hand side is of a different dimension than the left hand side, and neither is a scalar, must be replaced with a loop. For example,

```julia
x = Array{Float64,3}(undef, 2, 3, 4)
x .~ MvNormal(fill(0, 2), I)
```

should be replaced with something like

```julia
x = Array{Float64,3}(2, 3, 4)
for i in 1:3, j in 1:4
x[:, i, j] ~ MvNormal(fill(0, 2), I)
end
```

This release also completely rewrites the internal implementation of `.~`, where from now on all `.~` statements are turned into loops over `~` statements at macro time. However, the only breaking aspect of this change is the above change to what's allowed on the right hand side.

### Remove indexing by samplers

This release removes the feature of `VarInfo` where it kept track of which variable was associated with which sampler. This means removing all user-facing methods where `VarInfo`s where being indexed with samplers. In particular,
Expand All @@ -14,7 +82,7 @@ This release removes the feature of `VarInfo` where it kept track of which varia
- `unflatten` no longer accepts a sampler as an argument
- `eltype(::VarInfo)` no longer accepts a sampler as an argument
- `keys(::VarInfo)` no longer accepts a sampler as an argument
- `VarInfo(::VarInfo, ::Sampler, ::AbstactVector)` no longer accepts the sampler argument.
- `VarInfo(::VarInfo, ::Sampler, ::AbstractVector)` no longer accepts the sampler argument.

### Reverse prefixing order

Expand Down
4 changes: 0 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -35,7 +34,6 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
Expand All @@ -44,7 +42,6 @@ DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMooncakeExt = ["Mooncake"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[compat]
ADTypes = "1"
Expand Down Expand Up @@ -74,5 +71,4 @@ OrderedCollections = "1"
Random = "1.6"
Requires = "1"
Test = "1.6"
ZygoteRules = "0.2"
julia = "1.10"
2 changes: 0 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -447,10 +447,8 @@ DynamicPPL.Experimental.is_suitable_varinfo

```@docs
tilde_assume
dot_tilde_assume
```

```@docs
tilde_observe
dot_tilde_observe
```
25 changes: 0 additions & 25 deletions ext/DynamicPPLZygoteRulesExt.jl

This file was deleted.

4 changes: 0 additions & 4 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,9 @@ export AbstractVarInfo,
PrefixContext,
ConditionContext,
assume,
dot_assume,
observe,
dot_observe,
tilde_assume,
tilde_observe,
dot_tilde_assume,
dot_tilde_observe,
# Pseudo distributions
NamedDist,
NoDist,
Expand Down
Loading
Loading