Skip to content

Commit

Permalink
Address some type-stability issues (#6)
Browse files Browse the repository at this point in the history
* Fix up flatten

* Fix up parameters type stability

* Bump patch version

* Add gotcha to readme

* Fix up 1.3
  • Loading branch information
willtebbutt authored Dec 6, 2020
1 parent e52552d commit 1d1fffd
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ParameterHandling"
uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5"
authors = ["Invenia Technical Computing Corporation"]
version = "0.2.0"
version = "0.2.1"

[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,4 @@ package doesn't currently support the functionality that you need.
# Gotchas

1. `Integer`s typically don't take part in the kind of optimisation procedures that this package is designed to handle. Consequently, `flatten(::Integer)` produces an empty vector.
2. `deferred` has some type-stability issues when used in conjunction with abstract types. For example, `flatten(deferred(Normal, 5.0, 4.0))` won't infer properly. A simple work around is to write a function `normal(args...) = Normal(args...)` and work with `deferred(normal, 5.0, 4.0)` instead.
19 changes: 13 additions & 6 deletions src/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ function flatten(x::AbstractVector)
x_vecs_and_backs = map(flatten, x)
x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
function Vector_from_vec(x_vec)
sz = cumsum(map(length, x_vecs))
sz = _cumsum(map(length, x_vecs))
x_Vec = [backs[n](x_vec[sz[n] - length(x_vecs[n]) + 1:sz[n]]) for n in eachindex(x)]
return oftype(x, x_Vec)
end
return vcat(x_vecs...), Vector_from_vec
return reduce(vcat, x_vecs), Vector_from_vec
end

function flatten(x::AbstractArray)
Expand All @@ -54,11 +54,13 @@ function flatten(x::AbstractArray)
end

function flatten(x::Tuple)
x_vecs, unflattens = zip(map(flatten, x)...)
sz = cumsum(collect(map(length, x_vecs)))
x_vecs_and_backs = map(flatten, x)
x_vecs, x_backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
lengths = map(length, x_vecs)
sz = _cumsum(lengths)
function unflatten_to_Tuple(v::Vector{<:Real})
return ntuple(length(x)) do n
return unflattens[n](v[sz[n] - length(x_vecs[n]) + 1:sz[n]])
map(x_backs, lengths, sz) do x_back, l, s
return x_back(v[s - l + 1:s])
end
end
return reduce(vcat, x_vecs), unflatten_to_Tuple
Expand All @@ -81,3 +83,8 @@ function flatten(d::Dict)
end
return d_vec, unflatten_to_Dict
end

_cumsum(x) = cumsum(x)
if VERSION < v"1.5"
_cumsum(x::Tuple) = (_cumsum(collect(x))..., )
end
18 changes: 15 additions & 3 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,33 @@ using Test

using ParameterHandling: AbstractParameter, value

function test_flatten_interface(x::T) where {T}
function test_flatten_interface(x::T; check_inferred::Bool=true) where {T}

# Ensure that basic functionality is implemented.
v, unflatten = flatten(x)
@test v isa Vector{<:Real}
@test x == unflatten(v)
@test unflatten(v) isa T

# Check that everything infers properly.
if check_inferred
@inferred flatten(x)
end

return nothing
end

function test_parameter_interface(x)
function test_parameter_interface(x; check_inferred::Bool=true)

# Parameters need to be flatten-able.
test_flatten_interface(x)
test_flatten_interface(x; check_inferred=check_inferred)

# Run this to make sure that it doesn't error.
value(x)

if check_inferred
@inferred value(x)
end
return nothing
end

Expand Down
10 changes: 6 additions & 4 deletions test/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@
end

@testset "Tuple" begin
test_flatten_interface((1.0, 2.0))
test_flatten_interface((1.0, (2.0, 3.0), randn(5)))
test_flatten_interface((1.0, 2.0); check_inferred=tuple_infers)
test_flatten_interface((1.0, (2.0, 3.0), randn(5)); check_inferred=tuple_infers)
end

@testset "NamedTuple" begin
test_flatten_interface((a=1.0, b=(2.0, 3.0), c=(e=5.0,)))
test_flatten_interface(
(a=1.0, b=(2.0, 3.0), c=(e=5.0,)); check_inferred=tuple_infers,
)
end

@testset "Dict" begin
test_flatten_interface(Dict(:a => (a=4.0, b=3.0), :b => 5.0))
test_flatten_interface(Dict(:a => (a=4.0, b=3.0), :b => 5.0); check_inferred=false)
end
end
14 changes: 9 additions & 5 deletions test/parameters.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using ParameterHandling: Positive, Bounded

mvnormal(args...) = MvNormal(args...)
pdiagmat(args...) = PDiagMat(args...)

@testset "parameters" begin

@testset "postive" begin
Expand Down Expand Up @@ -30,14 +33,15 @@ using ParameterHandling: Positive, Bounded
end

@testset "deferred" begin
test_parameter_interface(deferred(sin, 0.5))
test_parameter_interface(deferred(sin, positive(0.5)))
test_parameter_interface(deferred(sin, 0.5); check_inferred=tuple_infers)
test_parameter_interface(deferred(sin, positive(0.5)); check_inferred=tuple_infers)
test_parameter_interface(
deferred(
MvNormal,
mvnormal,
fixed(randn(5)),
deferred(PDiagMat, positive.(rand(5) .+ 1e-1)),
)
deferred(pdiagmat, positive.(rand(5) .+ 1e-1)),
);
check_inferred=tuple_infers,
)
end

Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ using Zygote
using ParameterHandling: value
using ParameterHandling.TestUtils: test_flatten_interface, test_parameter_interface

const tuple_infers = VERSION < v"1.5" ? false : true

@testset "ParameterHandling.jl" begin
include("flatten.jl")
include("parameters.jl")
Expand Down

2 comments on commit 1d1fffd

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/25908

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.1 -m "<description of version>" 1d1fffdbdb099ae513c7c81b74109137287f8321
git push origin v0.2.1

Please sign in to comment.