-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: Docs on ProjectTo and embedded subspaces #412
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,116 @@ | ||||||||
# Types that represent embedded subspaces | ||||||||
_Taking Types Representing Embedded Subspaces Seriously™_ | ||||||||
|
||||||||
To paraphrase Stefan Karpinski: _"This does mean treating sparse matrixes not just as a representation of dense matrixes, but the alternative is just too horrible."_ | ||||||||
|
||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
Consider the following possible `rrule` for `*` | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
```julia | ||||||||
function rrule(::typeof(*), a, b) | ||||||||
mul_pullback(ȳ) = (NoTangent(), ȳ*b', a'*ȳ) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should you use the |
||||||||
return (a * b), mul_pullback | ||||||||
end | ||||||||
``` | ||||||||
|
||||||||
This seems perfectly reasonable for floats: | ||||||||
```julia | ||||||||
julia> _, pb = rrule(*, 2.0, 3.0) | ||||||||
(6.0, var"#mul_pullback#10"{Float64, Float64}(2.0, 3.0)) | ||||||||
|
||||||||
julia> pb(1.0) | ||||||||
(NoTangent(), 3.0, 2.0) | ||||||||
``` | ||||||||
and for matrixes | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
```julia | ||||||||
julia> _, pb = rrule(*, [1.0 2.0; 3.0 4.0], [10.0 20.0; 30.0 40.0]) | ||||||||
([70.0 100.0; 150.0 220.0], var"#mul_pullback#10"{Matrix{Float64}, Matrix{Float64}}([1.0 2.0; 3.0 4.0], [10.0 20.0; 30.0 40.0])) | ||||||||
|
||||||||
julia> pb([1.0 0.0; 0.0 1.0]) | ||||||||
(NoTangent(), [10.0 30.0; 20.0 40.0], [1.0 3.0; 2.0 4.0]) | ||||||||
``` | ||||||||
|
||||||||
and even for complex numbers (assuming you like conjugation [which we do](@ref complexfunctions)) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
```julia | ||||||||
julia> _, pb = rrule(*, 0.0 + im, 1.0 + im) | ||||||||
(-1.0 + 1.0im, var"#mul_pullback#10"{ComplexF64, ComplexF64}(0.0 + 1.0im, 1.0 + 1.0im)) | ||||||||
|
||||||||
julia> pb(1.0) | ||||||||
(NoTangent(), 1.0 - 1.0im, 0.0 - 1.0im) | ||||||||
|
||||||||
julia> pb(1.0im) | ||||||||
(NoTangent(), 1.0 + 1.0im, 1.0 + 0.0im) | ||||||||
``` | ||||||||
|
||||||||
So far everything is wonderful. | ||||||||
Isn't linear algebra great? We get this nice code that generalizes to all kinds of vector spaces. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One sentence per line.
Suggested change
|
||||||||
|
||||||||
What if we start mixing it up a bit. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
Let's try a real amd a complex | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
```julia | ||||||||
julia> _, pb = rrule(*, 2.0, 3.0im) | ||||||||
(0.0 + 6.0im, var"#mul_pullback#10"{Float64, ComplexF64}(2.0, 0.0 + 3.0im)) | ||||||||
|
||||||||
julia> pb(1.0) | ||||||||
(NoTangent(), 0.0 - 3.0im, 2.0) | ||||||||
|
||||||||
julia> pb(1.0im) | ||||||||
(NoTangent(), 3.0 + 0.0im, 0.0 + 2.0im) | ||||||||
``` | ||||||||
|
||||||||
That's _an_ answer. | ||||||||
It's consistent with treating the reals as being an embedded subspace of the complex numbers. | ||||||||
i.e. treating `2.0` as actually being `2.0 + 0im`. | ||||||||
It doesn't feel great that we have escaped from the real type in this way. | ||||||||
But it's not wrong as such. | ||||||||
|
||||||||
Over to matrixes, lets try as `Diagonal` with a `Matrix` | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
```julia | ||||||||
julia> _, pb = rrule(*, Diagonal([1.0, 2.0]), [1.0 2.0; 3.0 4.0]) | ||||||||
([1.0 2.0; 6.0 8.0], var"#mul_pullback#10"{Diagonal{Float64, Vector{Float64}}, Matrix{Float64}}([1.0 0.0; 0.0 2.0], [1.0 2.0; 3.0 4.0])) | ||||||||
|
||||||||
julia> pb([1.0 0.0; 0.0 1.0]) | ||||||||
(NoTangent(), [1.0 3.0; 2.0 4.0], [1.0 0.0; 0.0 2.0]) | ||||||||
``` | ||||||||
|
||||||||
This is also _an_ answer. | ||||||||
This seems even worse though: no only has it escaped the `Diagonal` type, it has even escaped the subspace it represents. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
Further, it is inconsistent with what we would get had we AD'd the function for `*(::Diagonal, ::Matrix)` directly. | ||||||||
[That primal function](https://github.com/JuliaLang/julia/blob/f7f46af8ff39a1b4c7000651c680058e9c0639f5/stdlib/LinearAlgebra/src/diagonal.jl#L224-L245) boils down to: | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this refer to a specific Julia version for simplicity? |
||||||||
```julia | ||||||||
*(a::Diagonal, b::AbstractMatrix) = a.diag .* b | ||||||||
``` | ||||||||
By reading that primal method, we know that that ADing that method would have zeros on the off-diagonals, because they are never even accessed | ||||||||
(a similar argument can be made for the complex part of a real number). | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
--- | ||||||||
|
||||||||
Consider the following possible `rrule` for `sum` | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
```julia | ||||||||
function rrule(::typeof(sum), x::AbstractArray) | ||||||||
sum_pullback(ȳ) = (NoTangent(), fill(ȳ, size(x))) | ||||||||
return sum(x), sum_pullback | ||||||||
end | ||||||||
``` | ||||||||
|
||||||||
This seems all well and good at first: | ||||||||
```julia | ||||||||
julia> _, pb = rrule(sum, [1.0 2.0; 3.0 4.0]) | ||||||||
(10.0, var"#sum_pullback#9"{Matrix{Float64}}([1.0 2.0; 3.0 4.0])) | ||||||||
|
||||||||
julia> pb(1.0) | ||||||||
(NoTangent(), [1.0 1.0; 1.0 1.0]) | ||||||||
``` | ||||||||
|
||||||||
But now consider: | ||||||||
```julia | ||||||||
julia> _, pb = rrule(sum, Diagonal([1.0, 2.0])) | ||||||||
(3.0, var"#sum_pullback#9"{Diagonal{Float64, Vector{Float64}}}([1.0 0.0; 0.0 2.0])) | ||||||||
|
||||||||
julia> pb(1.0) | ||||||||
(NoTangent(), [1.0 1.0; 1.0 1.0]) | ||||||||
``` | ||||||||
That's not right -- not if us saying this was `Diagonal` meant anything. | ||||||||
If you try and use that dense matrix, to do gradient descent on your `Diagonal` input, you will get a non-diagonal result: | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
`[2.0 1.0; 1.0 2.0]`. | ||||||||
You have escape the subspace that the diagonal type represents. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is phrased strangely and hard to follow.