Skip to content

Commit

Permalink
export only lower-case names (#5)
Browse files Browse the repository at this point in the history
* Bump patch

* Change to lower-case exports

* Refactor documentation

* Fix test bug

Co-authored-by: Letif Mones <[email protected]>

* Update src/parameters.jl

Co-authored-by: wytbella <[email protected]>

* Update src/parameters.jl

Co-authored-by: wytbella <[email protected]>

* Improve docs + tests

Co-authored-by: wt <[email protected]>
Co-authored-by: Letif Mones <[email protected]>
Co-authored-by: wytbella <[email protected]>
  • Loading branch information
4 people authored Sep 10, 2020
1 parent 5aada90 commit 4af076d
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 58 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ParameterHandling"
uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5"
authors = ["Invenia Technical Computing Corporation"]
version = "0.1.2"
version = "0.2.0"

[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Expand Down
2 changes: 1 addition & 1 deletion src/ParameterHandling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module ParameterHandling
using Bijectors
using Compat: only

export flatten, Positive, Bounded, Fixed, Deferred
export flatten, positive, bounded, fixed, deferred

include("flatten.jl")
include("parameters.jl")
Expand Down
79 changes: 50 additions & 29 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,29 @@ value(x::Tuple) = map(value, x)
value(x::NamedTuple) = map(value, x)
value(x::Dict) = Dict(k => value(v) for (k, v) in x)


"""
Positive(value::Real)
positive(val::Real, transform::Bijector=Bijectors.Exp(), ε::Real = 1e-12)
The `value` of a `Positive` is a `Real` number that is constrained to be positive. This is
represented in terms of an `unconstrained_value` and a `transform` that maps any value
the `unconstrained_value` might take to the positive reals.
Returns a `Postive`.
The `value` of a `Positive` is a `Real` number that is constrained to be positive.
This is represented in terms of an a `transform` that maps an `unconstrained_value` to the
positive reals.
Satisfies `val ≈ transform(unconstrained_value)`
"""
function positive(val::Real, transform::Bijector=Bijectors.Exp(), ε::Real = 1e-12)
if val <= 0
throw(ArgumentError("Value, $val, is not positive."))
end
unconstrained_value = inv(transform)(val - ε)
return Positive(unconstrained_value, transform, convert(typeof(unconstrained_value), ε))
end

struct Positive{T<:Real, V<:Bijector, Tε<:Real} <: AbstractParameter
unconstrained_value::T
transform::V
ε::Tε
end

Positive(value::Real) = Positive(value, Bijectors.Exp(), convert(typeof(value), 1e-12))

value(x::Positive) = x.transform(x.unconstrained_value) + x.ε

function flatten(x::Positive)
Expand All @@ -48,12 +55,32 @@ function flatten(x::Positive)
end

"""
Bounded(value::Real, lower_bound::Real, upper_bound::Real)
bounded(val::Real, lower_bound::Real, upper_bound::Real)
Constructs a `Bounded`.
The `value` of a `Bounded` is a `Real` number that is constrained to be within the interval
(`lower_bound`, `upper_bound`). This is represented in terms of an `unconstrained_value` and
a `transform` that maps any value to reals included by (`lower_bound`, `upper_bound`).
(`lower_bound`, `upper_bound`), and is equal to `val`.
This is represented internally in terms of an `unconstrained_value` and a `transform` that
maps any real to this interval. `unconstrained_value` is `inv(transform)(val)`.
"""
function bounded(val::Real, lower_bound::Real, upper_bound::Real)
lb = convert(typeof(val), lower_bound)
ub = convert(typeof(val), upper_bound)
ε = convert(typeof(val), 1e-12)

if val > upper_bound || val < lower_bound
throw(ArgumentError(
"Value, $val, outside of specified bounds ($lower_bound, $upper_bound).",
))
end

inv_transform = Bijectors.Logit(lb + ε, ub - ε)
transform = inv(inv_transform)

# Bijectors defines only Logit struct so we use Logistic as the inverse of Logit
return Bounded(inv_transform(val), lb, ub, transform, ε)
end

struct Bounded{T<:Real, V<:Bijector, Tε<:Real} <: AbstractParameter
unconstrained_value::T
lower_bound::T
Expand All @@ -62,34 +89,29 @@ struct Bounded{T<:Real, V<:Bijector, Tε<:Real} <: AbstractParameter
ε::Tε
end

function Bounded(value::Real, lower_bound::Real, upper_bound::Real)
lb = convert(typeof(value), lower_bound)
ub = convert(typeof(value), upper_bound)
ε = convert(typeof(value), 1e-12)

# Bijectors defines only Logit struct so we use Logistic as the inverse of Logit
return Bounded(value, lb, ub, inv(Bijectors.Logit(lb + ε, ub - ε)), ε)
end

value(x::Bounded) = x.transform(x.unconstrained_value)

function flatten(x::Bounded)
v, unflatten_to_Real = flatten(x.unconstrained_value)

function unflatten_Bounded(v_new::Vector{<:Real})
return Bounded(unflatten_to_Real(v_new), x.lower_bound, x.upper_bound, x.transform, x.ε)
return Bounded(
unflatten_to_Real(v_new), x.lower_bound, x.upper_bound, x.transform, x.ε,
)
end

return v, unflatten_Bounded
end

"""
Fixed{T}
fixed(val)
Represents a parameter whose value is required to stay constant. The `value` of a `Fixed` is
simply its value -- that constantness of the parameter is enforced by returning an empty
simply `val`. Constantness of the parameter is enforced by returning an empty
vector from `flatten`.
"""
fixed(val) = Fixed(val)

struct Fixed{T} <: AbstractParameter
value::T
end
Expand All @@ -104,20 +126,19 @@ function flatten(x::Fixed)
end

"""
Deferred(f, args...)
deferred(f, args...)
The `value` of a `Deferred` is `f(value(args)...)`. This makes it possible to make the value
The `value` of a `deferred` is `f(value(args)...)`. This makes it possible to make the value
of the `args` e.g. `AbstractParameter`s and, therefore, enforce constraints on them even if
`f` knows nothing about `AbstractParameters`.
It can be helpful to use `Deferred` recursively when constructing complicated objects.
It can be helpful to use `deferred` recursively when constructing complicated objects.
"""
deferred(f, args...) = Deferred(f, args)

struct Deferred{Tf, Targs} <: AbstractParameter
f::Tf
args::Targs
function Deferred(f::Tf, args...) where {Tf}
return new{Tf, typeof(args)}(f, args)
end
end

Base.:(==)(a::Deferred, b::Deferred) = (a.f == b.f) && (a.args == b.args)
Expand All @@ -129,7 +150,7 @@ function flatten(x::Deferred)
v, unflatten = flatten(x.args)

function unflatten_Deferred(v_new::Vector{<:Real})
return Deferred(x.f, unflatten(v_new)...)
return Deferred(x.f, unflatten(v_new))
end

return v, unflatten_Deferred
Expand Down
59 changes: 32 additions & 27 deletions test/parameters.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,42 @@
using ParameterHandling: Positive, Bounded

@testset "parameters" begin

@testset "Postive" begin
test_parameter_interface(Positive(5.0))
p = Positive(-10.0)
@test value(p) > p.ε
p = Positive(-Inf)
@test value(p) == p.ε
@testset "postive" begin
@testset "$val" for val in [5.0, 1e-11, 1e-12]
p = positive(val)
test_parameter_interface(p)
@test value(p) val
end

@test_throws ArgumentError positive(-0.1)
end

@testset "Bounded" begin
test_parameter_interface(Bounded(-1.0, -0.1, 2.0))
p = Bounded(-10.0, -0.1, 2.0)
@test value(p) > p.lower_bound
p = Bounded(-Inf, -0.1, 2.0)
@test value(p) == p.lower_bound + p.ε
p = Bounded(10.0, -0.1, 2.0)
@test value(p) < p.upper_bound
p = Bounded(Inf, -0.1, 2.0)
@test value(p) == p.upper_bound - p.ε
@testset "bounded" begin
@testset "$val" for val in [-0.05, -0.1 + 1e-12, 2.0 - 1e-11, 2.0 - 1e-12]
p = bounded(val, -0.1, 2.0)
test_parameter_interface(p)
@test value(p) val
end

@test_throws ArgumentError bounded(-0.05, 0.0, 1.0)
end

@testset "Fixed" begin
test_parameter_interface(Fixed((a=5.0, b=4.0)))
@testset "fixed" begin
val = (a=5.0, b=4.0)
p = fixed(val)
test_parameter_interface(p)
@test value(p) == val
end

@testset "Deferred" begin
test_parameter_interface(Deferred(sin, 0.5))
test_parameter_interface(Deferred(sin, Positive(log(0.5))))
@testset "deferred" begin
test_parameter_interface(deferred(sin, 0.5))
test_parameter_interface(deferred(sin, positive(0.5)))
test_parameter_interface(
Deferred(
deferred(
MvNormal,
Fixed(randn(5)),
Deferred(PDiagMat, Positive.(randn(5))),
fixed(randn(5)),
deferred(PDiagMat, positive.(rand(5) .+ 1e-1)),
)
)
end
Expand All @@ -41,7 +46,7 @@
return abs2.a) + abs2.b)
end

# This is more of a worked example. Will be properly split up / tidied up.
# This is more of a worked example.
@testset "Integration" begin

θ0 = (a=5.0, b=4.0)
Expand All @@ -62,7 +67,7 @@

@testset "Other Integration" begin

θ0 = (a=5.0, b=Fixed(4.0))
θ0 = (a=5.0, b=fixed(4.0))
flat_parameters, unflatten = flatten(θ0)

results = Optim.optimize(
Expand All @@ -85,7 +90,7 @@

@testset "Normal" begin

θ0 = Deferred(Normal, randn(), Positive(log(1.0)))
θ0 = deferred(Normal, randn(), positive(1.0))
flat_parameters, unflatten = flatten(θ0)

results = Optim.optimize(
Expand Down

2 comments on commit 4af076d

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/21167

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.0 -m "<description of version>" 4af076d5d48f0e949b41bb26d5834a774f8a144f
git push origin v0.2.0

Please sign in to comment.