Skip to content

Commit

Permalink
Add value_flatten as described in #37. (#38)
Browse files Browse the repository at this point in the history
* Add `value_flatten` as described in #37.

* Remove heavy testing for `value_flatten`

* Update README to use `value_flatten`

* Formatting changes suggested by JuliaFormatter
  • Loading branch information
jonniedie authored Sep 3, 2021
1 parent d6c1045 commit 21e6ff7
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ParameterHandling"
uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5"
authors = ["Invenia Technical Computing Corporation"]
version = "0.3.7"
version = "0.3.8"

[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Expand Down
29 changes: 17 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,15 @@ julia> value(unflatten(new_v)) # Obtain constrained value.
3.071174489325673
```

We also provide the utility function `value_flatten` which returns un unflattening function
equivalent to `value(unflatten(v))`. The above could then be implemented as
```julia
julia> v, unflatten = value_flatten(x);

julia> unflatten(x)
1.0
```

It is straightforward to implement your own parameters that interoperate with those already
written by implementing `value` and `flatten` for them. You might want to do this if this
package doesn't currently support the functionality that you need.
Expand Down Expand Up @@ -209,19 +218,15 @@ raw_initial_params = (
noise_var=positive(0.2),
)

# Using ParameterHandling.flatten, we can obtain both a Vector{Float64} representation of
# these parameters, and a mapping from that vector back to the original parameters:
flat_initial_params, unflatten = ParameterHandling.flatten(raw_initial_params)
# Using ParameterHandling.value_flatten, we can obtain both a Vector{Float64} representation of
# these parameters, and a mapping from that vector back to the original parameter values:
flat_initial_params, unflatten = ParameterHandling.value_flatten(raw_initial_params)

# ParameterHandling.value strips out all of the Positive types in initial_params,
# returning a plain named tuple of named tuples and Float64s.
julia> initial_params = ParameterHandling.value(raw_initial_params)
(k1 = (var = 0.9, precision = 1.0), k2 = (var = 0.10000000000000002, precision = 0.30000000000000004), noise_var = 0.19999999999999998)

# We define unpack to map directly from the flat Vector{Float64} representation to a
# the NamedTuple representation with all the Positive types removed.
unpack = ParameterHandling.value unflatten

# GP-specific functionality. Don't worry about the details, just
# note the use of the structured representation of the parameters.
function build_gp(params::NamedTuple)
Expand All @@ -243,13 +248,13 @@ end

# Use Optim.jl to minimise the objective function w.r.t. the params.
# The important thing here is to note that we're passing in the flat vector of parameters to
# Optim, which is something that Optim knows how to work with, and using `unpack` to convert
# Optim, which is something that Optim knows how to work with, and using `unflatten` to convert
# from this representation to the structured one that our objective function knows about
# using `unpack` -- we've used ParameterHandling to build a bridge between Optim and an
# using `unflatten` -- we've used ParameterHandling to build a bridge between Optim and an
# entirely unrelated package.
training_results = Optim.optimize(
objective unpack,
θ -> only(Zygote.gradient(objective unpack, θ)),
objective unflatten,
θ -> only(Zygote.gradient(objective unflatten, θ)),
flat_initial_params,
BFGS(
alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
Expand All @@ -260,7 +265,7 @@ training_results = Optim.optimize(
)

# Extracting the final values of the parameters.
final_params = unpack(training_results.minimizer)
final_params = unflatten(training_results.minimizer)
f_trained = build_gp(final_params)
```

Expand Down
3 changes: 2 additions & 1 deletion src/ParameterHandling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ using ChainRulesCore
using LinearAlgebra
using SparseArrays

export flatten, positive, bounded, fixed, deferred, orthogonal, positive_definite
export flatten,
value_flatten, positive, bounded, fixed, deferred, orthogonal, positive_definite

include("flatten.jl")
include("parameters.jl")
Expand Down
21 changes: 21 additions & 0 deletions src/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,24 @@ _cumsum(x) = cumsum(x)
if VERSION < v"1.5"
_cumsum(x::Tuple) = (_cumsum(collect(x))...,)
end

"""
value_flatten([eltype=Float64], x)
Operates similarly to `flatten`, but the returned `unflatten` function returns an object
like `x`, but with unwrapped values.
Doing
```julia
v, unflatten = value_flatten(x)
```
is the same as doing
```julia
v, _unflatten = flatten(x)
unflatten = ParameterHandling.value ∘ _unflatten
```
"""
function value_flatten(args...)
v, unflatten = flatten(args...)
return v, value unflatten
end
8 changes: 8 additions & 0 deletions test/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,12 @@ pdiagmat(args...) = PDiagMat(args...)
# Check that it's successfully optimised.
@test mean(value(unflatten(results.minimizer))) 0 atol = 1e-7
end

@testset "value_flatten" begin
x = (ones(3), fixed(5.0), (a=fixed(5.0), b=[6.0, 2.1]))
v, unflatten = value_flatten(x)

@test length(v) == 5
@test unflatten(v) == ParameterHandling.value(x)
end
end

2 comments on commit 21e6ff7

@willtebbutt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/44176

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.8 -m "<description of version>" 21e6ff717cf94ee46175c4ac09bf44cddd0d677f
git push origin v0.3.8

Please sign in to comment.