Change to DifferentiationInterface #46

Open · wants to merge 24 commits into base: main
11 changes: 7 additions & 4 deletions .github/workflows/CI.yml
@@ -1,15 +1,18 @@
name: CI

on:
push:
branches:
- main
tags: ['*']
pull_request:

concurrency:
# Skip intermediate builds: always.
# Cancel intermediate builds: only if it is a pull request build.
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}

jobs:
test:
name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ matrix.arch }} - ${{ github.event_name }}
@@ -19,17 +22,17 @@ jobs:
matrix:
version:
- '1'
- '1.6'
- 'min'
os:
- ubuntu-latest
arch:
- x64
steps:
- uses: actions/checkout@v3
- uses: julia-actions/setup-julia@v1
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
36 changes: 7 additions & 29 deletions Project.toml
@@ -1,11 +1,11 @@
name = "NormalizingFlows"
uuid = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
version = "0.1.1"
version = "0.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -15,36 +15,14 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
NormalizingFlowsEnzymeExt = "Enzyme"
NormalizingFlowsForwardDiffExt = "ForwardDiff"
NormalizingFlowsReverseDiffExt = "ReverseDiff"
NormalizingFlowsZygoteExt = "Zygote"

[compat]
ADTypes = "0.1, 0.2, 1"
Bijectors = "0.12.6, 0.13, 0.14"
DiffResults = "1"
ADTypes = "1"
Bijectors = "0.12.6, 0.13, 0.14, 0.15"
DifferentiationInterface = "0.6"
Distributions = "0.25"
DocStringExtensions = "0.9"
Enzyme = "0.11, 0.12, 0.13"
ForwardDiff = "0.10.25"
Optimisers = "0.2.16, 0.3"
Optimisers = "0.2.16, 0.3, 0.4"
ProgressMeter = "1.0.0"
Requires = "1"
ReverseDiff = "1.14"
StatsBase = "0.33, 0.34"
Zygote = "0.6"
julia = "1.6"

[extras]
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
julia = "1.10"
11 changes: 2 additions & 9 deletions docs/src/api.md
@@ -15,6 +15,7 @@ For an example of Gaussian VI, we can construct the flow as follows:
```@julia
using Distributions, Bijectors
T= Float32
@leaf MvNormal # to prevent params in q₀ from being optimized
q₀ = MvNormal(zeros(T, 2), ones(T, 2))
flow = Bijectors.transformed(q₀, Bijectors.Shift(zeros(T,2)) ∘ Bijectors.Scale(ones(T, 2)))
```
@@ -23,7 +24,7 @@ To train the Gaussian VI targeting the distribution $p$ via ELBO maximization, we run:
using NormalizingFlows

sample_per_iter = 10
flow_trained, stats, _ = train_flow(
flow_trained, stats, _ , _ = train_flow(
elbo,
flow,
logp,
@@ -83,11 +84,3 @@ NormalizingFlows.loglikelihood
```@docs
NormalizingFlows.optimize
```


## Utility Functions for Taking Gradient
```@docs
NormalizingFlows.grad!
NormalizingFlows.value_and_gradient!
```
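The api.md diff above changes the `train_flow` call to destructure four return values instead of three. A minimal usage sketch under the new signature — the `logp` target, iteration count, and keyword names here are illustrative assumptions, not taken verbatim from the diff:

```julia
using NormalizingFlows, Optimisers
using ADTypes: AutoZygote

logp(x) = -sum(abs2, x) / 2  # hypothetical standard-normal log-density target

sample_per_iter = 10

# `flow` is assumed to be constructed as in the Gaussian VI example above
flow_trained, stats, st, time_elapsed = train_flow(
    elbo, flow, logp, sample_per_iter;
    max_iters=5_000,
    optimiser=Optimisers.Adam(1f-2),
    ADbackend=AutoZygote(),
)
```

The fourth value, the elapsed training time, is simply discarded with `_` in the docs example; downstream code that matched on a 3-tuple needs updating.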

7 changes: 5 additions & 2 deletions docs/src/example.md
@@ -36,6 +36,7 @@ Here we used the `PlanarLayer()` from `Bijectors.jl` to construct a

```julia
using Bijectors, FunctionChains
using Functors

function create_planar_flow(n_layers::Int, q₀)
d = length(q₀)
@@ -45,7 +46,9 @@ function create_planar_flow(n_layers::Int, q₀)
end

# create a 20-layer planar flow
flow = create_planar_flow(20, MvNormal(zeros(Float32, 2), I))
@leaf MvNormal # to prevent params in q₀ from being optimized
q₀ = MvNormal(zeros(Float32, 2), I)
flow = create_planar_flow(20, q₀)
flow_untrained = deepcopy(flow) # keep a copy of the untrained flow for comparison
```
Notice that here the flow layers are chained together using the `fchain` function from [`FunctionChains.jl`](https://github.com/oschulz/FunctionChains.jl).
@@ -116,4 +119,4 @@ plot!(title = "Comparison of Trained and Untrained Flow", xlabel = "X", ylabel=

## Reference

- Rezende, D. and Mohamed, S., 2015. *Variational inference with normalizing flows*. International Conference on Machine Learning
- Rezende, D. and Mohamed, S., 2015. *Variational inference with normalizing flows*. International Conference on Machine Learning
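The `@leaf MvNormal` line added in both docs files marks the reference distribution as a Functors.jl leaf, so `Optimisers.destructure` no longer collects its mean and scale as trainable parameters. A hedged sketch of the effect (not code from the PR):

```julia
using Functors, Optimisers, Distributions

@leaf MvNormal  # stop Functors from traversing MvNormal's fields

q₀ = MvNormal(zeros(Float32, 2), ones(Float32, 2))
θ, re = Optimisers.destructure(q₀)
# θ is expected to be empty: q₀'s parameters stay fixed during training,
# and only the bijector layers wrapped around q₀ are optimized
```

Without the `@leaf` annotation, `destructure` would flatten the reference distribution's parameters into `θ` and the optimizer would move the base distribution as well.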
25 changes: 0 additions & 25 deletions ext/NormalizingFlowsEnzymeExt.jl

This file was deleted.

28 changes: 0 additions & 28 deletions ext/NormalizingFlowsForwardDiffExt.jl

This file was deleted.

22 changes: 0 additions & 22 deletions ext/NormalizingFlowsReverseDiffExt.jl

This file was deleted.

23 changes: 0 additions & 23 deletions ext/NormalizingFlowsZygoteExt.jl

This file was deleted.

46 changes: 13 additions & 33 deletions src/NormalizingFlows.jl
@@ -4,16 +4,14 @@ using Bijectors
using Optimisers
using LinearAlgebra, Random, Distributions, StatsBase
using ProgressMeter
using ADTypes, DiffResults
using ADTypes
using DifferentiationInterface

using DocStringExtensions

export train_flow, elbo, loglikelihood, value_and_gradient!

using ADTypes
using DiffResults
export train_flow, elbo, loglikelihood

"""
"""
train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)

Train the given normalizing flow `flow` by calling `optimize`.
@@ -56,47 +54,29 @@ function train_flow(
# use FunctionChains instead of simple compositions to construct the flow when many flow layers are involved
# otherwise the compilation time for destructure will be too long
θ_flat, re = Optimisers.destructure(flow)

loss(θ, rng, args...) = -vo(rng, re(θ), args...)

# Normalizing flow training loop
θ_flat_trained, opt_stats, st = optimize(
rng,
θ_flat_trained, opt_stats, st, time_elapsed = optimize(
ADbackend,
vo,
loss,
θ_flat,
re,
args...;
(rng, args...)...;
max_iters=max_iters,
optimiser=optimiser,
kwargs...,
)

flow_trained = re(θ_flat_trained)
return flow_trained, opt_stats, st
return flow_trained, opt_stats, st, time_elapsed
end

include("train.jl")


include("optimize.jl")
include("objectives.jl")

# optional dependencies
if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base
using Requires
end

# Question: should Exts be loaded here or in train.jl?
function __init__()
@static if !isdefined(Base, :get_extension)
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
"../ext/NormalizingFlowsForwardDiffExt.jl"
)
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
"../ext/NormalizingFlowsReverseDiffExt.jl"
)
@require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" include(
"../ext/NormalizingFlowsEnzymeExt.jl"
)
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include(
"../ext/NormalizingFlowsZygoteExt.jl"
)
end
end
end
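The four deleted `ext/` modules and the `Requires`-based `__init__` block above are replaced by DifferentiationInterface.jl, which dispatches on an ADTypes backend object instead of per-package glue code. A minimal sketch of the pattern — the toy loss is illustrative, standing in for the negated variational objective:

```julia
using DifferentiationInterface
using ADTypes: AutoForwardDiff
import ForwardDiff

loss(θ) = sum(abs2, θ)  # toy objective in place of θ -> -vo(rng, re(θ), args...)

backend = AutoForwardDiff()
val, grad = value_and_gradient(loss, backend, [1.0, 2.0])
# val == 5.0 and grad == [2.0, 4.0]
```

Swapping `AutoForwardDiff()` for `AutoZygote()`, `AutoReverseDiff()`, or `AutoEnzyme()` changes the differentiation engine without touching the loss code, which is why the backend-specific extensions could be deleted wholesale.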
2 changes: 1 addition & 1 deletion src/objectives.jl
@@ -1,2 +1,2 @@
include("objectives/elbo.jl")
include("objectives/loglikelihood.jl")
include("objectives/loglikelihood.jl") # not tested
2 changes: 1 addition & 1 deletion src/objectives/elbo.jl
@@ -42,4 +42,4 @@ end

function elbo(flow::Bijectors.TransformedDistribution, logp, n_samples)
return elbo(Random.default_rng(), flow, logp, n_samples)
end
end
7 changes: 5 additions & 2 deletions src/objectives/loglikelihood.jl
@@ -2,29 +2,32 @@
# training by minimizing forward KL (MLE)
####################################
"""
loglikelihood(flow::Bijectors.TransformedDistribution, xs::AbstractVecOrMat)
loglikelihood(rng, flow::Bijectors.TransformedDistribution, xs::AbstractVecOrMat)

Compute the log-likelihood of the variational distribution `flow` at a batch of samples `xs`
drawn from the target distribution p.

# Arguments
- `rng`: random number generator (unused; present only to match the signature of the other variational objectives)
- `flow`: variational distribution to be trained. In particular,
    `flow = transformed(q₀, T::Bijectors.Bijector)`,
    where q₀ is a reference distribution that is easy to sample from and whose logpdf is cheap to evaluate
- `xs`: samples from the target distribution p.

"""
function loglikelihood(
rng::AbstractRNG, # empty argument
flow::Bijectors.UnivariateTransformed, # variational distribution to be trained
xs::AbstractVector, # sample batch from target dist p
)
return mean(Base.Fix1(logpdf, flow), xs)
end

function loglikelihood(
rng::AbstractRNG, # empty argument
flow::Bijectors.MultivariateTransformed, # variational distribution to be trained
xs::AbstractMatrix, # sample batch from target dist p
)
llhs = map(x -> logpdf(flow, x), eachcol(xs))
return mean(llhs)
end
end
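The forward-KL (MLE) objective above plugs into the same `train_flow` entry point as `elbo`; since `train_flow` negates the objective internally, maximizing `loglikelihood` minimizes the forward KL to the target. A hedged sketch — the data matrix and keyword values are illustrative, not from the PR:

```julia
using NormalizingFlows

# hypothetical batch of target samples, one observation per column
xs = randn(Float32, 2, 1000)

# `flow` is assumed to be a MultivariateTransformed constructed as in the docs
flow_trained, stats, st, time_elapsed = train_flow(
    loglikelihood, flow, xs;
    max_iters=2_000,
)
```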