Add ChainRules #366
Merged
35 commits
1c3ee6d Add ChainRules fallback (oxinabox)
0280d45 Test ChainRules integration directly (oxinabox)
c2b37c5 use metaprogramming in blacklist (oxinabox)
ec5c65f compact choosing pullback mechanism code (oxinabox)
3d9bd62 Simplify conversion Zygote style pullback (oxinabox)
449a6a7 add ChainRules, rm DiffRules (oxinabox)
c9f2dd0 wrap CR types (MikeInnes)
56c8402 wrap `nothing` (MikeInnes)
f640621 record fastmath regression (MikeInnes)
6b354e5 conjugate appropriately (oxinabox)
b86e09e multiple input and multiple output functions with chainrules (oxinabox)
02a302c delete extra code added to test by mistake (oxinabox)
3e1e8e0 Integration test that we have identity working right (oxinabox)
5a9ec6a Nested AD working (oxinabox)
d9db59b delete rules that are in chainrules (oxinabox)
cb9ec4c delete diag rule also (oxinabox)
aa29144 bump verion (oxinabox)
1fd415a add in docs about prefer to use ChainRules (oxinabox)
d9a53a0 delete commented out fastmath code (oxinabox)
5cf8b10 comment all about the chainrules interface (oxinabox)
5300050 Update the Manifest.toml (oxinabox)
329b530 Update docs/src/adjoints.md (oxinabox)
8cd76f4 Update src/lib/array.jl (oxinabox)
497846a WIP: support kwargs (oxinabox)
2ff6cc3 Fix type inference (oxinabox)
513b787 Fix nexting (oxinabox)
29e0244 Update src/compiler/chainrules.jl (oxinabox)
325d304 Update src/compiler/chainrules.jl (oxinabox)
1432563 pre1.3 do not worry about edges (oxinabox)
e93eca1 Fix constant folding (oxinabox)
1b4db9b remove Manifest (oxinabox)
cfeda77 Remove pure as not needed anymore (oxinabox)
cef928c Remove special handling of multiple inputs to pullbacks from ChainRul… (oxinabox)
e7ecdd6 Follow up after rebase (oxinabox)
41f4c17 mark as broken the test that fails on 1.0 (oxinabox)
@@ -0,0 +1,111 @@
const chainrules_fallback = which(rrule, Tuple{Any})

"""
    has_chain_rrule(T)

For a type-tuple `T` e.g. `Tuple{typeof(f), Int, Float64}`, checks if there is an `rrule` defined for it,
excluding the generic fallback.
The first return value is `true` if the `rrule` exists, `false` otherwise.
If it does not, then the second return value is a list of edges to attach to the CodeInfo for a generated function,
such that if a suitable rule is defined later, the generated function will recompile.
"""
function has_chain_rrule(T)
    m = meta(Tuple{typeof(rrule),T.parameters...})
    if m.method !== chainrules_fallback
        # found a rrule, no need to add any edges
        return true, nothing
    end

    # did not find anything, will have to attach edges so it recompiles if one is added
    @static if VERSION >= v"1.3"
        @assert m.code.edges !== nothing
        return false, m.code.edges
    else
        # pre-julia 1.3 there are no edges
        return false, tuple()
    end
end

"""
    is_kwfunc(sigt...)

Determines if `sigt` is the type signature of a kwfunction.
Each element of `sigt` should be a type.
Either the first 3 types are a kwfunc type, a NamedTuple and the matching base function type,
or the first argument is the base function type and it is not a kwfunction.
The remaining types in `sigt` are the types of the arguments.
"""
is_kwfunc(::Vararg{Any}) = false
is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f)

"""
    wrap_chainrules_output(x)

Convert `x` from the differential types ChainRules uses to the format Zygote uses internally
(including conjugating complex gradients).
"""
@inline wrap_chainrules_output(x) = conj(unthunk(x))  # For now we are just not going to deal with thunks
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
for T_outer in (:Tuple, :NamedTuple)
    # we create separate methods rather than using a `Union` + an `if` so that we avoid a
    # branch that changes output type, because nested AD on that kinda thing makes Zygote less
    # than happy.
    @eval @inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T<:$T_outer}
        xp = map(wrap_chainrules_output, x)
        convert($T_outer, xp)
    end
end

"""
    wrap_chainrules_input(x)

Convert `x` from the format Zygote uses internally (including conjugated complex gradients)
to the differential types ChainRules uses.
"""
@inline wrap_chainrules_input(x) = conj(x)
@inline wrap_chainrules_input(::Nothing) = ChainRules.Zero()
@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple})
    xp = map(wrap_chainrules_input, xs)
    ChainRules.Composite{Any, typeof(xp)}(xp)
end

"""
    ZBack{F}(back) <: Function

Wrapper for a ChainRules pullback `back` that causes it to follow Zygote conventions.
(A functor is used here rather than a closure to avoid boxing issues.)
"""
struct ZBack{F} <: Function
    back::F
end
@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
# `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603
# though it might be worth keeping as a performance optimization (benchmarking pending)
@inline (s::ZBack)(::Nothing) = nothing

"""
    chain_rrule(f, args...)

Returns the (primal) value of `f(args...)` and a pullback, by invoking `ChainRulesCore.rrule(f, args...)`.
The pullback is appropriately wrapped up to follow Zygote conventions.
"""
@inline function chain_rrule(f, args...)
    y, back = rrule(f, args...)
    return y, ZBack(back)
end

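For readers following the wrapping logic above without ChainRules installed, here is a minimal dependency-free sketch of the same idea. All names in it (`MyZero`, `to_zygote`, `from_zygote`, `Wrapped`, `double_rule`) are hypothetical stand-ins for illustration, not the PR's or ChainRules' API, and thunks and `Composite` values are not modelled.

```julia
# Hypothetical stand-in for a ChainRules-style "zero" differential.
struct MyZero end

# ChainRules-style output -> Zygote-style output: zero markers become `nothing`,
# everything else is conjugated, tuples are converted element by element.
to_zygote(x) = conj(x)
to_zygote(::MyZero) = nothing
to_zygote(x::Tuple) = map(to_zygote, x)

# Zygote-style input -> ChainRules-style input.
from_zygote(x) = conj(x)
from_zygote(::Nothing) = MyZero()

# Functor wrapping a ChainRules-style pullback so it follows Zygote conventions.
struct Wrapped{F} <: Function
    back::F
end
(w::Wrapped)(dy) = to_zygote(w.back(from_zygote(dy)))
(w::Wrapped)(::Nothing) = nothing  # shortcut: no incoming gradient, nothing to propagate

# Hand-written rule for x -> 2x: returns the value and a pullback whose first
# output is the (zero) differential for the function itself.
double_rule(x) = (2x, dy -> (MyZero(), 2dy))

y, back = double_rule(3.0)
zback = Wrapped(back)
zback(1.0)      # (nothing, 2.0)
zback(nothing)  # nothing
```

The separate `::Nothing` method mirrors the `nothing -> nothing` shortcut in `ZBack`; without it, this sketch would pass `MyZero()` into the pullback instead of short-circuiting.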
"""
    chain_rrule_kw(kwf, kwargs, f, args...)

As per [`chain_rrule`](@ref) but with support for kwargs.
`kwf` should be the kwfunc matching `f`, and `kwargs` is a `NamedTuple` of keyword arguments.
"""
@inline function chain_rrule_kw(kwf, kwargs, f, args...)
    y, back = rrule(f, args...; kwargs...)
    kw_zpullback(dy) = (nothing, nothing, ZBack(back)(dy)...)  # first two nothings are for kwfunc and kwargs
    return y, kw_zpullback
end
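
The fallback check in `has_chain_rrule` can be tried at the REPL with a runtime analogue. The sketch below is an illustration only: `my_rrule`, `my_fallback` and `has_specific_rule` are hypothetical names, the fallback signature `(::Any, ::Vararg{Any})` is an assumption about how such a fallback is declared, and the edge tracking that the PR uses to trigger recompilation of the generated function is not modelled.

```julia
# Hypothetical stand-in for `rrule`: a generic fallback plus one specific rule.
my_rrule(::Any, ::Vararg{Any}) = nothing  # fallback: "no rule defined"
my_rrule(::typeof(sin), x::Number) = (sin(x), dy -> (nothing, dy * cos(x)))

# Look the fallback method up once so later lookups can be compared against it.
const my_fallback = which(my_rrule, Tuple{Any})

# A rule "exists" when dispatch selects any method other than the fallback.
function has_specific_rule(f, args...)
    sig = Tuple{typeof(f), map(typeof, args)...}
    return which(my_rrule, sig) !== my_fallback
end

has_specific_rule(sin, 1.0)  # true: the `sin` method is selected
has_specific_rule(cos, 1.0)  # false: only the fallback applies
```

Comparing `Method` objects by identity avoids calling the rule at all, which is why the check is cheap enough to run while deciding which pullback mechanism to use.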
maybe this is meant to be expanded
how so?
Note this is an internal docstring, so it is for people who are reading the code alongside it.