Skip to content

Commit

Permalink
Parameterise orthogonal matrices (#17)
Browse files Browse the repository at this point in the history
* Add LinearAlgebra

* Add LinearAlgebra

* Add nearest_orthogonal_matrix

* Add LinearAlgebra to tests

* Add orthogonal transform

* Bump patch

* Add inlining comment

* Add attribution and comments

* Update src/parameters.jl

Co-authored-by: Alex Robson <[email protected]>

* Fix and test equality on Orthogonal

* Test value(::Orthogonal) is orthogonal

* Set tolerance for is_almost_orthogonal

* Ensure that Zygote runs

* Generalise eltype to Real or Complex

* Test nearest_orthogonal_matrix(::ComplexF64)

* Revert type widening

* Turn orthogonality check into a test

Co-authored-by: wytbella <[email protected]>

* Rename variable for readability

Co-authored-by: wytbella <[email protected]>

* Rename variable for readability

Co-authored-by: wytbella <[email protected]>

Co-authored-by: Alex Robson <[email protected]>
Co-authored-by: wytbella <[email protected]>
  • Loading branch information
3 people authored Mar 12, 2021
1 parent 462ed7c commit b0b4618
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 2 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
name = "ParameterHandling"
uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5"
authors = ["Invenia Technical Computing Corporation"]
version = "0.3.0"
version = "0.3.1"

[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Expand Down
3 changes: 2 additions & 1 deletion src/ParameterHandling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ module ParameterHandling

using Bijectors
using Compat: only
using LinearAlgebra

export flatten, positive, bounded, fixed, deferred
export flatten, positive, bounded, fixed, deferred, orthogonal

include("flatten.jl")
include("parameters.jl")
Expand Down
40 changes: 40 additions & 0 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,43 @@ function flatten(::Type{T}, x::Deferred) where T<:Real
unflatten_Deferred(v_new::Vector{T}) = Deferred(x.f, unflatten(v_new))
return v, unflatten_Deferred
end

"""
nearest_orthogonal_matrix(X::StridedMatrix)
Project `X` onto the closest orthogonal matrix in Frobenius norm.
Originally used in varz: https://github.com/wesselb/varz/blob/master/varz/vars.py#L446
"""
@inline function nearest_orthogonal_matrix(X::StridedMatrix{<:Union{Real, Complex}})
# Inlining necessary for type inference for some reason.
U, _, V = svd(X)
return U * V'
end

"""
orthogonal(X::StridedMatrix{<:Real})
Produce a parameter whose `value` is constrained to be an orthogonal matrix. The argument `X` need not
be orthogonal.
This functionality projects `X` onto the nearest element subspace of orthogonal matrices (in
Frobenius norm) and is overparametrised as a consequence.
Originally used in varz: https://github.com/wesselb/varz/blob/master/varz/vars.py#L446
"""
orthogonal(X::StridedMatrix{<:Real}) = Orthogonal(X)

struct Orthogonal{TX<:StridedMatrix{<:Real}} <: AbstractParameter
X::TX
end

Base.:(==)(X::Orthogonal, Y::Orthogonal) = X.X == Y.X

value(X::Orthogonal) = nearest_orthogonal_matrix(X.X)

function flatten(::Type{T}, X::Orthogonal) where {T<:Real}
v, unflatten_to_Array = flatten(T, X.X)
unflatten_Orthogonal(v_new::Vector{T}) = Orthogonal(unflatten_to_Array(v_new))
return v, unflatten_Orthogonal
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
23 changes: 23 additions & 0 deletions test/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,29 @@ pdiagmat(args...) = PDiagMat(args...)
)
end

@testset "orthogonal" begin
is_almost_orthogonal(X::AbstractMatrix, tol) = norm(X'X - I) < tol

@testset "nearest_orthogonal_matrix($T)" for T in [Float64, ComplexF64]
X_orth = ParameterHandling.nearest_orthogonal_matrix(randn(T, 5, 4))
@test is_almost_orthogonal(X_orth, 1e-9)
X_orth_2 = ParameterHandling.nearest_orthogonal_matrix(X_orth)
@test X_orth X_orth_2 # nearest_orthogonal_matrix is a projection.
end

X = orthogonal(randn(5, 4))
@test X == X
test_parameter_interface(X)
@test is_almost_orthogonal(value(X), 1e-9)

# We do not implement any custom rrules, so we only check that `Zygote` is able to
# differentiate, and assume that the result is correct if it doesn't error.
@testset "Zygote" begin
_, pb = Zygote.pullback(X -> value(orthogonal(X)), randn(3, 2))
@test only(pb(randn(3, 2))) isa Matrix{<:Real}
end
end

function objective_function(unflatten, flat_θ::Vector{<:Real})
θ = value(unflatten(flat_θ))
return abs2.a) + abs2.b)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Bijectors
using Compat: only
using Distributions
using LinearAlgebra
using Optim
using ParameterHandling
using PDMats
Expand Down

2 comments on commit b0b4618

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/31802

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.1 -m "<description of version>" b0b46186e73792163cc9e385fe3dbfa1646da784
git push origin v0.3.1

Please sign in to comment.