Add ChainRules #366

Merged
merged 35 commits into from
May 28, 2020
35 commits
1c3ee6d
Add ChainRules fallback
oxinabox Oct 7, 2019
0280d45
Test ChainRules integration directly
oxinabox Oct 18, 2019
c2b37c5
use metaprogramming in blacklist
oxinabox Oct 20, 2019
ec5c65f
compact choosing pullback mechanism code
oxinabox Oct 23, 2019
3d9bd62
Simplify conversion Zygote style pullback
oxinabox Jan 13, 2020
449a6a7
add ChainRules, rm DiffRules
oxinabox Apr 20, 2020
c9f2dd0
wrap CR types
MikeInnes Apr 14, 2020
56c8402
wrap `nothing`
MikeInnes Apr 20, 2020
f640621
record fastmath regression
MikeInnes Apr 20, 2020
6b354e5
conjugate appropriately
oxinabox Apr 21, 2020
b86e09e
multiple input and multiple output functions with chainrules
oxinabox Apr 23, 2020
02a302c
delete extra code added to test by mistake
oxinabox Apr 23, 2020
3e1e8e0
Integration test that we have identity working right
oxinabox Apr 24, 2020
5a9ec6a
Nested AD working
oxinabox Apr 24, 2020
d9db59b
delete rules that are in chainrules
oxinabox Apr 28, 2020
cb9ec4c
delete diag rule also
oxinabox Apr 28, 2020
aa29144
bump verion
oxinabox Apr 29, 2020
1fd415a
add in docs about prefer to use ChainRules
oxinabox Apr 29, 2020
d9a53a0
delete commented out fastmath code
oxinabox Apr 30, 2020
5cf8b10
comment all about the chainrules interface
oxinabox Apr 30, 2020
5300050
Update the Manifest.toml
oxinabox Apr 30, 2020
329b530
Update docs/src/adjoints.md
oxinabox May 4, 2020
8cd76f4
Update src/lib/array.jl
oxinabox May 4, 2020
497846a
WIP: support kwargs
oxinabox May 4, 2020
2ff6cc3
Fix type inference
oxinabox May 5, 2020
513b787
Fix nexting
oxinabox May 5, 2020
29e0244
Update src/compiler/chainrules.jl
oxinabox May 5, 2020
325d304
Update src/compiler/chainrules.jl
oxinabox May 5, 2020
1432563
pre1.3 do not worry about edges
oxinabox May 6, 2020
e93eca1
Fix constant folding
oxinabox May 6, 2020
1b4db9b
remove Manifest
oxinabox May 7, 2020
cfeda77
Remove pure as not needed anymore
oxinabox May 7, 2020
cef928c
Remove special handling of multiple inputs to pullbacks from ChainRul…
oxinabox May 20, 2020
e7ecdd6
Follow up after rebase
oxinabox May 28, 2020
41f4c17
mark as broken the test that fails on 1.0
oxinabox May 28, 2020
12 changes: 4 additions & 8 deletions Project.toml
@@ -5,7 +5,7 @@ version = "0.4.20"
 [deps]
 AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
 ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
-DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
+ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
@@ -14,25 +14,21 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
-SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

 [compat]
 AbstractFFTs = "0.5"
-ArrayLayouts = "0.1, 0.2, 0.3"
-DiffRules = "0.0, 0.1, 1"
+ArrayLayouts = "0.1, 0.2"
+ChainRules = "0.6.0"
 FillArrays = "0.8"
 ForwardDiff = "0"
-IRTools = "0.3"
+IRTools = "=0.3.1"
 MacroTools = "0.5"
 NNlib = "0.6.5"
-NaNMath = "0"
 Requires = "0.5, 1.0"
-SpecialFunctions = "0"
 ZygoteRules = "0.2"
 julia = "1"

10 changes: 10 additions & 0 deletions docs/src/adjoints.md
@@ -1,5 +1,15 @@
# Custom Adjoints

!!! note "Prefer to use ChainRules to define custom adjoints"
Zygote supports the use of [ChainRulesCore](http://www.juliadiff.org/ChainRulesCore.jl/stable/) to define custom sensitivities.
It is preferred to define the custom sensitivities using `ChainRulesCore.rrule` as they will work for many AD systems, not just Zygote.
These sensitivities can be added in your own package, or for Base functions they can be added to ChainRules.jl.

This documentation exists to describe how Zygote works, and how adjoints can be directly defined for Zygote.
Defining adjoints this way does not make them accessible to other AD systems, but does let you do things that directly depend on how Zygote works.
It allows for specific definitions of adjoints that are only defined for Zygote (which might work differently from more generic definitions defined for all AD systems).


The `@adjoint` macro is an important part of Zygote's interface; customising your backwards pass is not only possible but widely used and encouraged. While there are specific utilities available for common things like gradient clipping, understanding adjoints will give you the most flexibility. We first give a bit more background on what these pullback things are.

## Pullbacks
2 changes: 2 additions & 0 deletions src/Zygote.jl
@@ -6,6 +6,7 @@ using ArrayLayouts: MemoryLayout, AbstractColumnMajor

import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty

using ChainRules: ChainRules, rrule, unthunk
using IRTools
using MacroTools, Requires
using MacroTools: @forward
@@ -17,6 +18,7 @@ include("tools/buffer.jl")

include("compiler/reverse.jl")
include("compiler/emit.jl")
include("compiler/chainrules.jl")
include("compiler/interface.jl")
include("compiler/show.jl")

111 changes: 111 additions & 0 deletions src/compiler/chainrules.jl
@@ -0,0 +1,111 @@
const chainrules_fallback = which(rrule, Tuple{Any})

"""
has_chain_rrule(T)

For a type-tuple `T`, e.g. `Tuple{typeof(f), Int, Float64}`, checks whether an `rrule` is defined for it,
excluding the generic fallback.

Member:
maybe this is meant to be expanded

Member Author:
how so?
Note this is an internal docstring, so it is for people who are reading the code alongside it.

The first return value is `true` if the `rrule` exists, `false` otherwise.
If it does not exist, the second return value is a list of edges to attach to the CodeInfo for a generated function,
such that if a suitable rule is defined later, the generated function will recompile.
"""
function has_chain_rrule(T)
  m = meta(Tuple{typeof(rrule),T.parameters...})
  if m.method !== chainrules_fallback
    # found a rrule, no need to add any edges
    return true, nothing
  end

  # did not find anything, will have to attach edges so it recompiles if one is added
  @static if VERSION >= v"1.3"
    @assert m.code.edges !== nothing
    return false, m.code.edges
  else
    # pre-julia 1.3 there are no edges
    return false, tuple()
  end
end
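
The fallback comparison used by `has_chain_rrule` can be tried in isolation. The following is a minimal sketch in plain Julia, not the code from this PR: `toy_rrule`, `toy_fallback` and `has_toy_rule` are hypothetical stand-ins for `rrule`, `chainrules_fallback` and `has_chain_rrule`, and it uses `which` rather than Zygote's internal `meta`, so it does not deal with edges or recompilation at all.

```julia
# Hypothetical stand-in for `rrule`: a generic fallback that returns `nothing`,
# plus one specific rule for `sin`.
toy_rrule(::Any, ::Vararg{Any}) = nothing
toy_rrule(::typeof(sin), x::Number) = (sin(x), dy -> (nothing, dy * cos(x)))

# The method that is selected when no specific rule applies.
toy_fallback = which(toy_rrule, Tuple{Any})

# `T` is a signature type-tuple such as `Tuple{typeof(sin), Float64}`.
# A rule "exists" when dispatch picks anything other than the fallback method.
has_toy_rule(T) = which(toy_rrule, T) !== toy_fallback

has_toy_rule(Tuple{typeof(sin), Float64})  # a specific rule matches
has_toy_rule(Tuple{typeof(cos), Float64})  # only the fallback matches
```

Comparing `Method` objects by identity is what lets the check distinguish "dispatch found a real rule" from "dispatch fell through to the catch-all".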

"""
is_kwfunc(sigt...)

willtebbutt marked this conversation as resolved.
Determines if `sigt` is the type signature of a kwfunction.
Each element of `sigt` should be a type.
Either the first 3 types are a kwfunc type, a NamedTuple and the matching base function type,
or the first argument is the base function type and it is not a kwfunction.
The remaining types in `sigt` are the types of the arguments.

"""
is_kwfunc(::Vararg{Any}) = false
is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f)
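
As a rough illustration of the two signature shapes the docstring describes, the sketch below repeats the two `is_kwfunc` methods and applies them to a hypothetical function `scaled`. It assumes `Core.kwfunc` and `Core.kwftype` are available in the running Julia version; these are internal and their behaviour has changed between releases, so treat this as a sketch rather than a guarantee.

```julia
is_kwfunc(::Vararg{Any}) = false
is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k === Core.kwftype(f)

# Hypothetical function with a keyword argument.
scaled(x; scale=1) = scale * x

# A keyword call `scaled(3; scale=2)` is lowered to a call of this shape:
# (kwfunc, NamedTuple of keywords, base function, positional args...)
kwf = Core.kwfunc(scaled)
kwf((scale=2,), scaled, 3)

# Signature of the keyword call: kwfunc type, NamedTuple type, function type, arg types.
is_kwfunc(typeof(kwf), typeof((scale=2,)), typeof(scaled), Int)

# Signature of the plain call: function type followed by arg types.
is_kwfunc(typeof(scaled), Int)
```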


"""
wrap_chainrules_output(x)

Convert `x` from the differential types ChainRules uses to the format Zygote uses internally
(including conjugating complex gradients).
"""
@inline wrap_chainrules_output(x) = conj(unthunk(x)) # For now we are just not going to deal with thunks
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
for T_outer in (:Tuple, :NamedTuple)
  # we create separate methods rather than using a `Union` + an `if` so that we avoid a
  # branch that changes output type, because nested AD on that kinda thing makes Zygote less
  # than happy.
  @eval @inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T<:$T_outer}
    xp = map(wrap_chainrules_output, x)
    convert($T_outer, xp)
  end
end

"""
wrap_chainrules_input(x)

Convert `x` from the format Zygote uses internally (including conjugated complex gradients)
to the differential types ChainRules uses.
"""
@inline wrap_chainrules_input(x) = conj(x)
@inline wrap_chainrules_input(::Nothing) = ChainRules.Zero()
@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple})
  xp = map(wrap_chainrules_input, xs)
  ChainRules.Composite{Any, typeof(xp)}(xp)
end
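
The two conversion directions can be pictured with a self-contained toy. This sketch is an assumption-laden simplification: `ToyZero`, `to_zygote` and `to_chainrules` are hypothetical names, `ToyZero` stands in for `ChainRules.Zero`, thunks are ignored, and tuples are left as plain tuples instead of being wrapped in `ChainRules.Composite`.

```julia
struct ToyZero end   # hypothetical stand-in for `ChainRules.Zero`

# ChainRules-style value -> Zygote-style value: conjugate, and map zeros to `nothing`.
to_zygote(x) = conj(x)
to_zygote(::ToyZero) = nothing
to_zygote(x::Tuple) = map(to_zygote, x)

# Zygote-style value -> ChainRules-style value: conjugate, and map `nothing` to a zero.
to_chainrules(x) = conj(x)
to_chainrules(::Nothing) = ToyZero()
to_chainrules(x::Tuple) = map(to_chainrules, x)

zygote_side = to_zygote((1 + 2im, ToyZero(), 3.0))
round_trip = to_chainrules(zygote_side)
```

Because conjugation is its own inverse, converting out and back returns the original values, which is the property the wrapper pair relies on.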

"""
ZBack{F}(back) <: Function

Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conventions.
(A functor is used here rather than a closure to avoid boxing issues.)
"""
struct ZBack{F} <: Function
  back::F
end
@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
# `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603
# though it might be worth keeping as a performance optimization (benchmarking pending)
@inline (s::ZBack)(::Nothing) = nothing
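
The functor pattern behind `ZBack` can be sketched without ChainRules. In the toy below, `ToyBack` and `raw_back` are hypothetical; the input and output conversions are reduced to plain `conj`, which is only the complex-number part of what `wrap_chainrules_input` and `wrap_chainrules_output` do.

```julia
# A callable struct that wraps a raw pullback and converts on the way in and out.
struct ToyBack{F} <: Function
    back::F
end
(s::ToyBack)(dy) = map(conj, s.back(conj(dy)))
# Short-circuit: a `nothing` cotangent produces `nothing` without calling the rule.
(s::ToyBack)(::Nothing) = nothing

# Hypothetical raw pullback returning one cotangent per input.
raw_back(dy) = (dy * (3 + 0im), dy * (2 + 1im))

zb = ToyBack(raw_back)
zb(1 + 1im)
zb(nothing)
```

Storing the pullback in a typed field keeps its concrete type visible to the compiler, which is the motivation the docstring gives for preferring a functor over a closure.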

"""
chain_rrule(f, args...)

Returns the (primal) value of `f(args...)` and a pullback, by invoking `ChainRulesCore.rrule(f, args...)`.
The pullback is appropriately wrapped up to follow Zygote conventions.
"""
@inline function chain_rrule(f, args...)
  y, back = rrule(f, args...)
  return y, ZBack(back)
end

oxinabox marked this conversation as resolved.

"""
chain_rrule_kw(kwf, kwargs, f, args...)

As per [`chain_rrule`](@ref) but with support for kwargs.
`kwf` should be the kwfunc matching `f`, and `kwargs` is a `NamedTuple` of keyword arguments.
"""
@inline function chain_rrule_kw(kwf, kwargs, f, args...)
  y, back = rrule(f, args...; kwargs...)
  kw_zpullback(dy) = (nothing, nothing, ZBack(back)(dy)...) # first two nothings are for kwfunc and kwargs
  return y, kw_zpullback
end
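
The shape of the keyword-aware pullback is easier to see with a toy rule. This sketch assumes nothing from ChainRules: `scaled`, `toy_rrule` and `toy_rrule_kw` are hypothetical, and the `ZBack` conversion step is omitted so that only the two extra leading `nothing`s are on display.

```julia
# Hypothetical function with a keyword argument.
scaled(x; factor=1) = factor * x

# Hypothetical rule in the same shape as an `rrule`: the pullback returns one
# entry for the function itself, then one per positional argument.
toy_rrule(::typeof(scaled), x; factor=1) =
    (scaled(x; factor=factor), dy -> (nothing, dy * factor))

function toy_rrule_kw(kwf, kwargs, f, args...)
    y, back = toy_rrule(f, args...; kwargs...)
    # Two extra `nothing`s: one for the kwfunc, one for the kwargs NamedTuple.
    kw_back(dy) = (nothing, nothing, back(dy)...)
    return y, kw_back
end

# `kwf` is unused by the toy, so `nothing` is passed in its place.
y, back = toy_rrule_kw(nothing, (factor=3,), scaled, 2.0)
grads = back(1.0)
```

The result has four entries because the lowered keyword call has four inputs: the kwfunc, the keyword `NamedTuple`, the base function, and `x`.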
1 change: 0 additions & 1 deletion src/compiler/interface.jl
@@ -35,7 +35,6 @@ end
# interface2.jl

# Wrappers

_pullback(f, args...) = _pullback(Context(), f, args...)

tailmemaybe(::Nothing) = nothing
11 changes: 11 additions & 0 deletions src/compiler/interface2.jl
@@ -6,13 +6,24 @@ ignore_sig(T) = all(T -> T <: Type, T.parameters)
@generated function _pullback(ctx::AContext, f, args...)
  T = Tuple{f,args...}
  ignore_sig(T) && return :(f(args...), Pullback{$T}(()))

  iskw = is_kwfunc(f, args...)
  # if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function
  base_T = iskw ? Tuple{args[2:end]...} : T
  hascr, cr_edges = has_chain_rrule(base_T)
  chain_rrule_f = iskw ? :chain_rrule_kw : :chain_rrule
  hascr && return :($chain_rrule_f(f, args...))

  g = try _lookup_grad(T) catch e e end
  !(g isa Tuple) && return :(f(args...), Pullback{$T}((f,)))
  meta, forw, _ = g
  argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
  forw = varargs!(meta, forw, 3)
  # IRTools.verify(forw)
  forw = slots!(pis!(inlineable!(forw)))
  @static if VERSION >= v"1.3" # no edges pre-1.3
    append!(meta.code.edges, cr_edges) # be ready to swap to using chainrule if one is declared
  end
  return update!(meta.code, forw)
end

16 changes: 2 additions & 14 deletions src/lib/array.jl
@@ -347,8 +347,6 @@ end
@adjoint parent(x::LinearAlgebra.Adjoint) = parent(x), ȳ -> (LinearAlgebra.Adjoint(ȳ),)
@adjoint parent(x::LinearAlgebra.Transpose) = parent(x), ȳ -> (LinearAlgebra.Transpose(ȳ),)

@adjoint dot(x::AbstractArray, y::AbstractArray) = dot(x, y), Δ->(Δ .* y, Δ .* x)

function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1)
mat1_rsh = reshape(mat1,(1,m1,1,n1))
@@ -361,18 +359,6 @@ end

@adjoint kron(a::AbstractMatrix, b::AbstractMatrix) = pullback(_kron, a, b)

@adjoint function Diagonal(d::AbstractVector)
back(Δ::NamedTuple) = (Δ.diag,)
back(Δ::AbstractMatrix) = (diag(Δ),)
return Diagonal(d), back
end

@adjoint diag(A::AbstractMatrix) = diag(A), Δ->(Diagonal(Δ),)

@adjoint det(xs::Union{Number, AbstractMatrix}) = det(xs), Δ -> (Δ * det(xs) * inv(xs)',)

@adjoint logdet(xs::Union{Number, AbstractMatrix}) = logdet(xs), Δ -> (Δ * inv(xs)',)

@adjoint logabsdet(xs::AbstractMatrix) = logabsdet(xs), Δ -> (Δ[1] * inv(xs)',)

@adjoint function inv(A::Union{Number, AbstractMatrix})
Expand Down Expand Up @@ -737,6 +723,8 @@ end
end
end

# ChainRules has this also but does not use FillArrays, so we have our own definition
# for improved performance. See https://github.com/JuliaDiff/ChainRules.jl/issues/46
Zygote.@adjoint function LinearAlgebra.tr(x::AbstractMatrix)
# x is a square matrix checked by tr,
# so we could just use Eye(size(x, 1))
78 changes: 5 additions & 73 deletions src/lib/number.jl
@@ -1,37 +1,3 @@
using DiffRules, SpecialFunctions, NaNMath
using Base.FastMath: fast_op, make_fastmath

@nograd isinf, isnan, isfinite, div

# TODO use CSE here

for (M, f, arity) in DiffRules.diffrules()
arity == 1 || continue
Δ = :Δ
dx = DiffRules.diffrule(M, f, :x)
if f in [:abs, :abs2]
Δ = :(real($Δ))
else
dx = :(conj($dx))
end
@eval begin
@adjoint $M.$f(x::Number) = $M.$f(x),
Δ -> ($Δ * $dx,)
end
end

for (M, f, arity) in DiffRules.diffrules()
arity == 2 || continue
f == :^ && continue
da, db = DiffRules.diffrule(M, f, :a, :b)
@eval begin
@adjoint $M.$f(a::Number, b::Number) = $M.$f(a, b),
Δ -> (Δ * conj($da), Δ * conj($db))
end
end

@adjoint Base.:^(x::Number, p::Number) = x^p,
Δ -> (Δ * conj(p * x^(p-1)), Δ * conj(x^p * log(complex(x))))
@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
Base.literal_pow(^,x,Val(p)),
Δ -> (nothing, Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), nothing)
@@ -45,20 +11,6 @@ end

@adjoint Base.:+(xs::Number...) = +(xs...), Δ -> map(_ -> Δ, xs)

@adjoint Base.muladd(x::Number, y::Number, z::Number) =
Base.muladd(x, y, z), ō -> (y'ō, x'ō, ō)

@adjoint Base.fma(x::Number, y::Number, z::Number) =
Base.fma(x, y, z), ō -> (y'ō, x'ō, ō)

@adjoint function sincos(x)
s, c = sincos(x)
(s, c), ((s̄, c̄),) -> (s̄*c - c̄*s,)
end

@adjoint acosh(x::Complex) =
acosh(x), Δ -> (Δ * conj(inv(sqrt(x - 1) * sqrt(x + 1))),)

@adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, - c̄ * a // b // b))

@nograd floor, ceil, trunc, round, hash
@@ -71,28 +23,8 @@ end
@adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),)
@adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,)

DiffRules._abs_deriv(x::Complex) = x/abs(x)

# adjoint for Fastmath operations
for (f, fastf) in fast_op
if DiffRules.hasdiffrule(:Base, f, 1)
dx = DiffRules.diffrule(:Base, f, :x)
Δ = :Δ
if f in [:abs, :abs2]
Δ = :(real($Δ))
else
dx = :(conj($dx))
end
@eval begin
@adjoint Base.FastMath.$fastf(x::Number) =
Base.FastMath.$fastf(x), Δ -> ($Δ * make_fastmath($dx),)
end
elseif DiffRules.hasdiffrule(:Base, f, 2)
dx, dy = DiffRules.diffrule(:Base, f, :x, :y)
@eval begin
@adjoint Base.FastMath.$fastf(x::Number, y::Number) =
Base.FastMath.$fastf(x, y),
Δ -> (Δ * make_fastmath(conj($dx)), Δ * make_fastmath(conj($dy)))
end
end
end
# we intentionally define these here rather than falling back on ChainRules.jl
# because ChainRules doesn't really handle nonanalytic complex functions
willtebbutt marked this conversation as resolved.
@adjoint abs(x::Real) = abs(x), Δ -> (real(Δ)*sign(x),)
@adjoint abs(x::Complex) = abs(x), Δ -> (real(Δ)*x/abs(x),)
@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),)
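
The complex `abs` and `abs2` pullbacks above can be sanity-checked numerically without Zygote. This is a sketch under one stated assumption: for a real-valued function of a complex input, the gradient is taken to be ∂f/∂re + im * ∂f/∂im. The names `abs_pullback`, `abs2_pullback` and `fd_gradient` are hypothetical helpers, not part of the PR.

```julia
# Pullback formulas copied from the `@adjoint` definitions, as plain functions.
abs_pullback(x::Complex, Δ) = real(Δ) * x / abs(x)
abs2_pullback(x::Number, Δ) = real(Δ) * (x + x)

# Central finite differences along the real and imaginary axes, combined as
# ∂f/∂re + im * ∂f/∂im (the convention assumed in this sketch).
function fd_gradient(f, x::Complex; h=1e-6)
    dre = (f(x + h) - f(x - h)) / (2 * h)
    dim = (f(x + h * im) - f(x - h * im)) / (2 * h)
    return dre + dim * im
end

x = 3.0 + 4.0im
abs2_pullback(x, 1.0)   # compare with fd_gradient(abs2, x)
abs_pullback(x, 1.0)    # compare with fd_gradient(abs, x)
```

Neither function is complex-analytic, so no single complex derivative exists; the pullbacks are instead expressed through the real and imaginary partial derivatives, which is what the finite-difference check compares against.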