Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ChainRules #366

Merged
merged 35 commits into from
May 28, 2020
Merged
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
1c3ee6d
Add ChainRules fallback
oxinabox Oct 7, 2019
0280d45
Test ChainRules integration directly
oxinabox Oct 18, 2019
c2b37c5
use metaprogramming in blacklist
oxinabox Oct 20, 2019
ec5c65f
compact choosing pullback mechanism code
oxinabox Oct 23, 2019
3d9bd62
Simplify conversion Zygote style pullback
oxinabox Jan 13, 2020
449a6a7
add ChainRules, rm DiffRules
oxinabox Apr 20, 2020
c9f2dd0
wrap CR types
MikeInnes Apr 14, 2020
56c8402
wrap `nothing`
MikeInnes Apr 20, 2020
f640621
record fastmath regression
MikeInnes Apr 20, 2020
6b354e5
conjugate appropriately
oxinabox Apr 21, 2020
b86e09e
multiple input and multiple output functions with chainrules
oxinabox Apr 23, 2020
02a302c
delete extra code added to test by mistake
oxinabox Apr 23, 2020
3e1e8e0
Integration test that we have identity working right
oxinabox Apr 24, 2020
5a9ec6a
Nested AD working
oxinabox Apr 24, 2020
d9db59b
delete rules that are in chainrules
oxinabox Apr 28, 2020
cb9ec4c
delete diag rule also
oxinabox Apr 28, 2020
aa29144
bump verion
oxinabox Apr 29, 2020
1fd415a
add in docs about prefer to use ChainRules
oxinabox Apr 29, 2020
d9a53a0
delete commented out fastmath code
oxinabox Apr 30, 2020
5cf8b10
comment all about the chainrules interface
oxinabox Apr 30, 2020
5300050
Update the Manifest.toml
oxinabox Apr 30, 2020
329b530
Update docs/src/adjoints.md
oxinabox May 4, 2020
8cd76f4
Update src/lib/array.jl
oxinabox May 4, 2020
497846a
WIP: support kwargs
oxinabox May 4, 2020
2ff6cc3
Fix type inference
oxinabox May 5, 2020
513b787
Fix nexting
oxinabox May 5, 2020
29e0244
Update src/compiler/chainrules.jl
oxinabox May 5, 2020
325d304
Update src/compiler/chainrules.jl
oxinabox May 5, 2020
1432563
pre1.3 do not worry about edges
oxinabox May 6, 2020
e93eca1
Fix constant folding
oxinabox May 6, 2020
1b4db9b
remove Manifest
oxinabox May 7, 2020
cfeda77
Remove pure as not needed anymore
oxinabox May 7, 2020
cef928c
Remove special handling of multiple inputs to pullbacks from ChainRul…
oxinabox May 20, 2020
e7ecdd6
Follow up after rebase
oxinabox May 28, 2020
41f4c17
mark as broken the test that fails on 1.0
oxinabox May 28, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
@@ -10,12 +10,33 @@ function has_chain_rrule(T)
end

# For now we are just not going to deal with thunks
wrap_chainrules(x) = unthunk(x)
wrap_chainrules(x::Tuple) = map(wrap_chainrules, x)
wrap_chainrules_output(x) = conj(unthunk(x))
wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T}
T_outer = T <: NamedTuple ? NamedTuple : Tuple # must be a Tuple or NamedTuple, don't care about exact parameter types
# Composite supports map as name preserving, and is fast
xp = map(wrap_chainrules_output, x)
convert(T_outer, xp)
end

wrap_chainrules_input(x) = conj(x)
wrap_chainrules_input(x::Tuple) = map(wrap_chainrules_input, x)
wrap_chainrules_input(::Nothing) = ChainRules.Zero()
function wrap_chainrules_input(xs::NamedTuple)
xs_comp = ChainRules.Composite{Any}(xs...)
# Composite supports map as name preserving, and is fast
xs_comp_p = map(wrap_chainrules_input, xs_comp)
end


function chain_rrule(f, args...)
y, By = rrule(f, args...)
back(::Nothing) = nothing
back(dy) = wrap_chainrules(By(dy))
return y, back
#@info "Using ChainRule" f, typeof.(args)
y, back = rrule(f, args...)

zpullback(dy) = wrap_chainrules_output(back(wrap_chainrules_input(dy)))
# `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603
# though it might be worth keeping as a performance optimization (benchmarking pending)
zpullback(::Nothing) = nothing

y, zpullback
end
5 changes: 5 additions & 0 deletions src/lib/number.jl
Original file line number Diff line number Diff line change
@@ -70,6 +70,11 @@ end
@adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),)
@adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,)

@adjoint abs(x::Real) = abs(x), Δ -> (real(Δ)*sign(x),)
@adjoint abs(x::Complex) = abs(x), Δ -> (real(Δ)*x/abs(x),)
@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),)


# DiffRules._abs_deriv(x::Complex) = x/abs(x)

# # adjoint for Fastmath operations