Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test Enzyme on demo models / varinfos #813

Draft
wants to merge 8 commits into
base: release-0.35
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## 0.35.0

**Breaking**
**Breaking changes**

### Remove indexing by samplers

Expand Down Expand Up @@ -49,6 +49,51 @@ This release removes the feature of `VarInfo` where it kept track of which varia

This change also affects sampling in Turing.jl.

### `LogDensityFunction` argument order

- The method `LogDensityFunction(varinfo, model, context)` has been removed.
The only accepted order is `LogDensityFunction(model, varinfo, context; adtype)`.
(For an explanation of `adtype`, see below.)
The varinfo and context arguments are both still optional.

**Other changes**

### `LogDensityProblems` interface

LogDensityProblemsAD is now removed as a dependency.
Instead of constructing a `LogDensityProblemAD.ADgradient` object, we now directly use `DifferentiationInterface` to calculate the gradient of the log density with respect to model parameters.

Note that if you wish, you can still construct an `ADgradient` out of a `LogDensityFunction` object (there is nothing preventing this).

However, in this version, `LogDensityFunction` now takes an extra AD type argument.
If this argument is not provided, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient.
However, if you do pass an AD type, that will allow you to calculate the gradient as well.
You may thus find that it is easier to instead do this:

```julia
@model f() = ...

ldf = LogDensityFunction(f(); adtype=AutoForwardDiff())
```

This will return an object which satisfies the `LogDensityProblems` interface to first-order, i.e. you can now directly call both

```
LogDensityProblems.logdensity(ldf, params)
LogDensityProblems.logdensity_and_gradient(ldf, params)
```

without having to construct a separate `ADgradient` object.

If you prefer, you can also use `setadtype` to tack on the AD type afterwards:

```julia
@model f() = ...

ldf = LogDensityFunction(f()) # by default, no adtype set
ldf_with_ad = setadtype(ldf, AutoForwardDiff())
```

## 0.34.2

- Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied.
Expand Down
9 changes: 3 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -29,7 +29,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Expand All @@ -38,7 +37,6 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMooncakeExt = ["Mooncake"]
Expand All @@ -54,15 +52,14 @@ Bijectors = "0.13.18, 0.14, 0.15"
ChainRulesCore = "1"
Compat = "4"
ConstructionBase = "1.5.4"
DifferentiationInterface = "0.6.39"
Distributions = "0.25"
DocStringExtensions = "0.9"
KernelAbstractions = "0.9.33"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
KernelAbstractions = "0.9.33"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "6"
MacroTools = "0.5.6"
Mooncake = "0.4.59"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ logjoint

### LogDensityProblems.jl interface

The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by simply wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`:
The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`.

```@docs
DynamicPPL.LogDensityFunction
Expand Down
54 changes: 0 additions & 54 deletions ext/DynamicPPLForwardDiffExt.jl

This file was deleted.

1 change: 0 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ using MacroTools: MacroTools
using ConstructionBase: ConstructionBase
using Accessors: Accessors
using LogDensityProblems: LogDensityProblems
using LogDensityProblemsAD: LogDensityProblemsAD

using LinearAlgebra: LinearAlgebra, Cholesky

Expand Down
1 change: 1 addition & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ at which point it will return the sampler of that context.
getsampler(context::SamplingContext) = context.sampler
getsampler(context::AbstractContext) = getsampler(NodeTrait(context), context)
getsampler(::IsParent, context::AbstractContext) = getsampler(childcontext(context))
getsampler(::IsLeaf, ::AbstractContext) = error("No sampler found in context")

"""
struct DefaultContext <: AbstractContext end
Expand Down
Loading
Loading