diff --git a/Project.toml b/Project.toml
index 3a18157..e17b62f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "ParameterHandling"
 uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5"
 authors = ["Invenia Technical Computing Corporation"]
-version = "0.3.7"
+version = "0.3.8"
 
 [deps]
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
diff --git a/README.md b/README.md
index 14da837..b3ec8cd 100644
--- a/README.md
+++ b/README.md
@@ -163,6 +163,15 @@ julia> value(unflatten(new_v)) # Obtain constrained value.
 3.071174489325673
 ```
 
+We also provide the utility function `value_flatten`, which returns an unflattening
+function equivalent to `value ∘ unflatten`. The above could then be implemented as
+```julia
+julia> v, unflatten = value_flatten(x);
+
+julia> unflatten(new_v)
+3.071174489325673
+```
+
 It is straightforward to implement your own parameters that interoperate with those
 already written by implementing `value` and `flatten` for them. You might want to do
 this if this package doesn't currently support the functionality that you need.
@@ -209,19 +218,15 @@ raw_initial_params = (
     noise_var=positive(0.2),
 )
 
-# Using ParameterHandling.flatten, we can obtain both a Vector{Float64} representation of
-# these parameters, and a mapping from that vector back to the original parameters:
-flat_initial_params, unflatten = ParameterHandling.flatten(raw_initial_params)
+# Using ParameterHandling.value_flatten, we can obtain both a Vector{Float64} representation
+# of these parameters, and a mapping from that vector back to the original parameter values:
+flat_initial_params, unflatten = ParameterHandling.value_flatten(raw_initial_params)
 
 # ParameterHandling.value strips out all of the Positive types in initial_params,
 # returning a plain named tuple of named tuples and Float64s.
 julia> initial_params = ParameterHandling.value(raw_initial_params)
 (k1 = (var = 0.9, precision = 1.0), k2 = (var = 0.10000000000000002, precision = 0.30000000000000004), noise_var = 0.19999999999999998)
 
-# We define unpack to map directly from the flat Vector{Float64} representation to a
-# the NamedTuple representation with all the Positive types removed.
-unpack = ParameterHandling.value ∘ unflatten
-
 # GP-specific functionality. Don't worry about the details, just
 # note the use of the structured representation of the parameters.
 function build_gp(params::NamedTuple)
@@ -243,13 +248,13 @@ end
 
 # Use Optim.jl to minimise the objective function w.r.t. the params.
 # The important thing here is to note that we're passing in the flat vector of parameters to
-# Optim, which is something that Optim knows how to work with, and using `unpack` to convert
+# Optim, which is something that Optim knows how to work with, and using `unflatten` to convert
 # from this representation to the structured one that our objective function knows about
-# using `unpack` -- we've used ParameterHandling to build a bridge between Optim and an
+# -- we've used ParameterHandling to build a bridge between Optim and an
 # entirely unrelated package.
 training_results = Optim.optimize(
-    objective ∘ unpack,
-    θ -> only(Zygote.gradient(objective ∘ unpack, θ)),
+    objective ∘ unflatten,
+    θ -> only(Zygote.gradient(objective ∘ unflatten, θ)),
     flat_initial_params,
     BFGS(
         alphaguess = Optim.LineSearches.InitialStatic(scaled=true),
@@ -260,7 +265,7 @@ training_results = Optim.optimize(
 )
 
 # Extracting the final values of the parameters.
-final_params = unpack(training_results.minimizer)
+final_params = unflatten(training_results.minimizer)
 f_trained = build_gp(final_params)
 ```
diff --git a/src/ParameterHandling.jl b/src/ParameterHandling.jl
index 2e4d75f..bc7f78e 100644
--- a/src/ParameterHandling.jl
+++ b/src/ParameterHandling.jl
@@ -6,7 +6,8 @@ using ChainRulesCore
 using LinearAlgebra
 using SparseArrays
 
-export flatten, positive, bounded, fixed, deferred, orthogonal, positive_definite
+export flatten,
+    value_flatten, positive, bounded, fixed, deferred, orthogonal, positive_definite
 
 include("flatten.jl")
 include("parameters.jl")
diff --git a/src/flatten.jl b/src/flatten.jl
index e79946d..717b27b 100644
--- a/src/flatten.jl
+++ b/src/flatten.jl
@@ -104,3 +104,24 @@ _cumsum(x) = cumsum(x)
 if VERSION < v"1.5"
     _cumsum(x::Tuple) = (_cumsum(collect(x))...,)
 end
+
+"""
+    value_flatten([eltype=Float64], x)
+
+Operates similarly to `flatten`, but the returned `unflatten` function returns an object
+like `x`, with the parameter values unwrapped.
+
+Doing
+```julia
+v, unflatten = value_flatten(x)
+```
+is the same as doing
+```julia
+v, _unflatten = flatten(x)
+unflatten = ParameterHandling.value ∘ _unflatten
+```
+"""
+function value_flatten(args...)
+    v, unflatten = flatten(args...)
+    return v, value ∘ unflatten
+end
diff --git a/test/parameters.jl b/test/parameters.jl
index d6e3cc3..5f0485c 100644
--- a/test/parameters.jl
+++ b/test/parameters.jl
@@ -162,4 +162,12 @@ pdiagmat(args...) = PDiagMat(args...)
         # Check that it's successfully optimised.
         @test mean(value(unflatten(results.minimizer))) ≈ 0 atol = 1e-7
     end
+
+    @testset "value_flatten" begin
+        x = (ones(3), fixed(5.0), (a=fixed(5.0), b=[6.0, 2.1]))
+        v, unflatten = value_flatten(x)
+
+        @test length(v) == 5
+        @test unflatten(v) == ParameterHandling.value(x)
+    end
 end
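For reviewers, here is a minimal end-to-end sketch of the API this patch adds. It is illustrative only, not part of the diff, and the parameter names (`scale`, `offset`, `weights`) are made up for the example:

```julia
using ParameterHandling

# A structured parameter set mixing constrained, fixed, and plain values.
params = (scale=positive(2.0), offset=fixed(1.0), weights=[0.5, -0.3])

# value_flatten composes ParameterHandling.value with the unflattening
# function that flatten would return, so unflatten(v) yields plain values.
v, unflatten = value_flatten(params)

length(v)     # 3 -- `fixed` fields contribute no entries to the flat vector
unflatten(v)  # (scale ≈ 2.0, offset = 1.0, weights = [0.5, -0.3])
```

The payoff, as in the README's Optim example above, is that the single `unflatten` returned here replaces the hand-rolled `unpack = ParameterHandling.value ∘ unflatten` composition.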