Skip to content

Commit

Permalink
Insert _project into getproperty's gradient, and then improve `z2…
Browse files Browse the repository at this point in the history
…d` etc. to restore stability (#1104)

* insert _project into getproperty

* use zygote2differential in _project

* improve type-stability of zygote2differential

* fix 1 test and break 2

* 2 not broken in fact

* skip inference test

* skip more inference tests

* improve inference for 1.6

* skip a test on 1.6

* skip 2

* handle nothings

* re-enable some inference tests on 1.6

* arrays of abstract tangents, and NamedTuple tests

* reverse dispatch for wrap_chainrules_input

* fix a typo

* fix more notation

* restore a test

* add DynamicPPL.jl

* fix a test

* try removing piracy

* restore some piracy, tidy

* reinterpret

* reinterpret

* collapse nothings

* DistributionsAD too

* collapse zeros in z2d

* comments

* indents

* change one comment
  • Loading branch information
mcabbott authored Nov 7, 2021
1 parent 60f53e7 commit 4ed3a86
Show file tree
Hide file tree
Showing 9 changed files with 239 additions and 87 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ jobs:
package:
- {user: FluxML, repo: Flux.jl, group: All}
- {user: FluxML, repo: NNlib.jl, group: All}
- {user: TuringLang, repo: DynamicPPL.jl, group: All}
- {user: TuringLang, repo: DistributionsAD.jl, group: Zygote}
- {user: SciML, repo: DiffEqFlux.jl, group: Layers}
- {user: SciML, repo: NeuralPDE.jl, group: NNPDE}
steps:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.29"
version = "0.6.30"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
154 changes: 116 additions & 38 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,28 +115,61 @@ for T_outer in (:Tuple, :NamedTuple)
ChainRulesCore.backing(xp) # this is accessing ChainRulesCore internals, but it is prob safe enough, and it is fastest
end
end
# Could `reinterpret` instead of broadcasting here -- TODO
@inline wrap_chainrules_output(xs::AbstractArray{<:ChainRules.Tangent}) = wrap_chainrules_output.(xs)
wrap_chainrules_output(dxs::AbstractArray{<:Number}) = dxs
wrap_chainrules_output(dxs::AbstractArray{<:AbstractArray{<:Number}}) = dxs
wrap_chainrules_output(dxs::AbstractArray) = map(wrap_chainrules_output, dxs)
#=
# As an optimisation, we can convert by `reinterpret` for bitstypes, e.g. arrays of tuples of numbers
@inline function wrap_chainrules_output(dxs::AbstractArray{<:ChainRules.Tangent{<:Any, B}}) where {B}
if isbitstype(B)
# B is the backing type. It still contains NoTangent etc, which need converting to Nothing
reinterpret(wrap_chainrules_output(B), dxs)
else
map(wrap_chainrules_output, dxs)
end
end
wrap_chainrules_output(::Type{<:AbstractZero}) = Nothing
wrap_chainrules_output(::Type{NamedTuple{L,T}}) where {L,T} = NamedTuple{L,wrap_chainrules_output(T)}
@generated function wrap_chainrules_output(::Type{T}) where T<:Tuple
inner = map(wrap_chainrules_output, T.parameters)
:(Tuple{$(inner...)})
end
=#

"""
wrap_chainrules_input(x)
wrap_chainrules_input(dx)
Convert `x` from the format Zygote uses internally to differentials types ChainRules uses.
Convert `dx` from the format Zygote uses internally to differentials types ChainRules uses.
"""
@inline wrap_chainrules_input(x) = x
@inline wrap_chainrules_input(dx) = dx
@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent()
@inline wrap_chainrules_input(::Tuple{Vararg{Nothing}}) = ChainRules.ZeroTangent()
@inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent()
@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple})
xp = map(wrap_chainrules_input, xs)
ChainRules.Tangent{Any, typeof(xp)}(xp)
@inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple})
xp = map(wrap_chainrules_input, dxs)
# This produces Tangent{Any} since it does not get to see the primal, `x`.
ChainRulesCore.Tangent{Any, typeof(xp)}(xp)
end
# For mutable types, including x=Ref(1), Zygote makes Ref{Any}(::NamedTuple)
@inline wrap_chainrules_input(x::Ref) = wrap_chainrules_input(x[])
# Could `reinterpret` instead of broadcasting here -- TODO
@inline wrap_chainrules_input(xs::AbstractArray{<:Ref}) = wrap_chainrules_input.(xs)
@inline wrap_chainrules_input(xs::AbstractArray{<:Union{Nothing, <:Ref}}) = wrap_chainrules_input.(xs) # no test invented for this
@inline wrap_chainrules_input(xs::AbstractArray{<:NamedTuple}) = wrap_chainrules_input.(xs)
@inline wrap_chainrules_input(xs::AbstractArray{<:Union{Nothing, <:NamedTuple}}) = wrap_chainrules_input.(xs)
@inline wrap_chainrules_input(dx::Ref) = wrap_chainrules_input(dx[])
# For arrays, whitelist the safe ones, but always look inside Any[]:
@inline wrap_chainrules_input(dxs::AbstractArray{<:Number}) = dxs
@inline wrap_chainrules_input(dxs::AbstractArray{<:AbstractArray{<:Number}}) = dxs
@inline wrap_chainrules_input(dxs::AbstractArray) = map(wrap_chainrules_input, dxs)

#=
# Could `reinterpret` instead here? See issue 1112.
# One easy case, might be this:
@inline wrap_chainrules_input(xs::Base.ReinterpretArray{<:NamedTuple, <:Tangent}) = parent(xs)
# This is for `z2d` reinterpret below:
wrap_chainrules_input(::Type{Nothing}) = NoTangent
wrap_chainrules_input(::Type{NamedTuple{L,T}}) where {L,T} = NamedTuple{L,wrap_chainrules_input(T)}
@generated function wrap_chainrules_input(::Type{T}) where T<:Tuple
inner = map(wrap_chainrules_input, T.parameters)
:(Tuple{$(inner...)})
end
=#

"""
_project(x, dx)
Expand All @@ -146,21 +179,13 @@ Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`.
Safe to apply to arbitrary input.
"""
@inline function _project(x, dx)
# Note that this use of `wrap_chainrules_input` has the primal `x`, so could
# avoid making `Tangent{Any}`, perhaps via `zygote2differential` -- TODO.
wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx)))
wrap_chainrules_output(ProjectTo(x)(zygote2differential(dx, x)))
end

# Restore splatted arrays
_project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x)))

# Piracy:
# wrap_chainrules_input doesn't handle array of Union{Int,Nothing}
(::ChainRulesCore.ProjectTo)(::Nothing) = ChainRulesCore.NoTangent()

# CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any}
(project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im))

# CRC likes Tangent{AbstractArray}, but Zygote makes Tangent{Any}
# in particular this would hit https://github.com/JuliaDiff/ChainRulesCore.jl/blob/2ec2549b73b22bc08f554dae864fb650cfb9c3d7/src/projection.jl#L139
# if we were not losing track of the Primal in the Tangent
Expand Down Expand Up @@ -236,32 +261,85 @@ end
zygote2differential(dx, primal)
Convert input `dx` from the Zygote format to the ChainRules differential types.
This is similar to `wrap_chainrules_input(dx)`, but because it gets `primal::T`,
it can turn `NamedTuple`s into `Tangent{T}(...)` not `Tangent{Any}(...)`.
"""
zygote2differential(x, primal) = z2d(x, primal)
zygote2differential(::Nothing, ::Any) = NoTangent()
zygote2differential(t::Tuple, primal::Tuple) = map(z2d, t, primal)
zygote2differential(t::Tuple, primal) = (@warn "primal should be a tuple, not $primal"; return t)
z2d(x, ::Any) = x

z2d(::Nothing, ::Any) = NoTangent()
z2d(a::AbstractArray{<:Number}, primal::AbstractArray{T}) where T = a
# Could probably `reinterpret` instead of broadcasting here -- TODO
z2d(a::AbstractArray, primal::AbstractArray{T}) where T = z2d.(a, primal)
z2d(::Tuple{Vararg{Nothing}}, ::Tuple) = NoTangent() # collapse all-zero case
z2d(dx, ::Any) = dx
z2d(dx::AbstractArray{<:Number}, primal::AbstractArray) = dx
z2d(dx::AbstractArray{<:AbstractArray{<:Number}}, primal::AbstractArray) = dx
z2d(dx::AbstractArray, primal::AbstractArray) = map(z2d, dx, primal)
#=
# As an optimisation, we can convert by `reinterpret` for bitstypes, e.g. arrays of tuples of numbers
function z2d(dx::AbstractArray{S}, primal::AbstractArray{P}) where {S,P}
if isbitstype(S)
T = wrap_chainrules_input(S)
reinterpret(Tangent{P,T}, dx)
else
map(z2d, dx, primal)
end
end
=#

# Note: this should never be hit if we are converting things right, but it seems to be
# happening in the wild for sufficiently weird functions/types.
# This fixes most (all?) cases, but it would be good to find what we miss.
z2d(x::Union{AbstractZero, Tangent}, ::Any) = return x
function z2d(t::Tuple, primal::Tuple)
tp::Tuple = map(z2d, t, primal)
primal_type = typeof(primal)
return canonicalize(Tangent{primal_type, typeof(tp)}(tp))

function z2d(delta::Tuple, primal::Tuple)
backing = map(z2d, delta, primal)
if backing isa Tuple{Vararg{AbstractZero}}
return NoTangent() # collapse all-zero case
else
return canonicalize(Tangent{typeof(primal), typeof(backing)}(backing))
end
end

function z2d(t::NamedTuple, primal)
primal_type = typeof(primal)
fnames = fieldnames(primal_type)
complete_t = NamedTuple{fnames}(fn in keys(t) ? t[fn] : nothing for fn in fnames)
primals = NamedTuple{fnames}(getfield(primal, fn) for fn in fnames)
tp::NamedTuple = map(z2d, complete_t, primals)
return canonicalize(Tangent{primal_type, typeof(tp)}(tp))
# Dict handling in Zygote is a mess... should this become a `Tangent{Dict,Dict}` ?
# Right now it uses a NamedTuple but not for fields of the AbstractDict struct
z2d(dx::NamedTuple, primal::AbstractDict) = dx

function z2d(delta::NamedTuple, primal::T) where T # arbitrart struct
fnames = fieldnames(T)
deltas = map(n -> get(delta, n, nothing), fnames)
primals = map(n -> getfield(primal, n), fnames)
inner = map(z2d, deltas, primals) # recurse into fields
if inner isa Tuple{Vararg{AbstractZero}}
return NoTangent() # collapse all-zero case
else
backing = NamedTuple{fnames}(inner)
return canonicalize(Tangent{T, typeof(backing)}(backing))
end
end

# Dict case matches signature for ambiguity reasons:
z2d(dx::NamedTuple{L,S}, primal::AbstractDict) where {L,S<:Tuple{Vararg{Union{Number,Nothing}}}} = dx
# On Julia <= 1.6, this fixes easy cases which do not require recursion into fields, e.g.
# @inferred Zygote.z2d((re=1, im=nothing), 3.0+im)
@generated function z2d(delta::NamedTuple{L,S}, primal::T) where {L,S<:Tuple{Vararg{Union{Number,Nothing}}}, T}
fnames = fieldnames(T)
deltas = map(fnames) do n
i = findfirst(isequal(n), L)
if i == nothing || S.parameters[i] == Nothing
:(NoTangent())
else
:(delta.$n)
end
end
if all(d -> d == :(NoTangent()), deltas)
return :(NoTangent()) # collapse all-zero case
else
return quote
backing = NamedTuple{$fnames}(($(deltas...),))
Tangent{$T, typeof(backing)}(backing)
end
end
end

z2d(dx::Ref, primal) = z2d(dx[], primal) # mutable structs
3 changes: 2 additions & 1 deletion src/lib/lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ end
function back(Δ)
accum_param(__context__, val, Δ) === nothing && return
if isimmutable(x)
((; nt_nothing(x)..., pair(Val(f), Δ, x)...), nothing)
dx = (; nt_nothing(x)..., pair(Val(f), Δ, x)...)
(_project(x, dx), nothing)
else
dx = grad_mut(__context__, x)
dx[] = (; dx[]..., pair(Val(f), accum(getfield(dx[], f), Δ))...)
Expand Down
24 changes: 24 additions & 0 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,27 @@ end
@fastmath x^2.0
end == (4.0,)
end

@testset "zygote2differential inference" begin
@test @inferred(Zygote.z2d(1.0, 2.0)) isa Real
@test @inferred(Zygote.z2d([1,2,3], [4,5,6])) isa Vector
@test @inferred(Zygote.z2d((1, 2.0, 3+4im), (5, 6.0, 7+8im))) isa Tangent{<:Tuple}

# Below Julia 1.7, these need a @generated version to be inferred:
@test @inferred(Zygote.z2d((re=1,), 3.0+im)) isa Tangent{ComplexF64}
@test @inferred(Zygote.z2d((re=1, im=nothing), 3.0+im)) isa Tangent{ComplexF64}

# collapse nothings
@test @inferred(Zygote.z2d((nothing,), (1,))) === NoTangent()
@test @inferred(Zygote.z2d((nothing, nothing), (1,2))) === NoTangent()

# To test the generic case, we need a struct within a struct.
nested = Tangent{Base.RefValue{ComplexF64}}(; x=Tangent{ComplexF64}(; re=1, im=NoTangent()),)
if VERSION > v"1.7-"
@test @inferred(Zygote.z2d((; x=(; re=1)), Ref(3.0+im))) == nested
@test @inferred(Zygote.z2d((; x=(; re=nothing)), Ref(3.0+im))) === NoTangent()
else
@test Zygote.z2d((; x=(; re=1)), Ref(3.0+im)) == nested
@test Zygote.z2d((; x=(; re=nothing)), Ref(3.0+im)) === NoTangent()
end
end
4 changes: 3 additions & 1 deletion test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ end
y, back = @inferred pullback(x -> x.m, g)
@test y == getfield(g, :m)
# This type instability is due to the handling of non-bitstypes in `accum_param`
@test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
if VERSION > v"1.7-"
@test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
end
@test back([1., 0, 0]) == ((m = [1.0, 0.0, 0.0], P = nothing),)

Base.getproperty(g::Gaussian, s::Symbol) = 2getfield(g, s)
Expand Down
37 changes: 35 additions & 2 deletions test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,13 @@ end

@test gradient(t -> t[1]*t[2], (2, 3)) == ((3, 2),)

@test gradient(x -> x.re, 2+3im) === (1.0 + 0.0im,)
@test gradient(x -> x.re, 2+3im) === (1.0 + 0.0im,) # one NamedTuple
@test gradient(x -> x.re*x.im, 2+3im) == (3.0 + 2.0im,) # two, different fields
@test gradient(x -> x.re*x.im + x.re, 2+3im) == (4.0 + 2.0im,) # three, with accumulation

@test gradient(x -> x.re*x.im, 2+3im) == (3.0 + 2.0im,)
@test gradient(x -> abs2(x * x.re), 4+5im) == (456.0 + 160.0im,) # gradient participates
@test gradient(x -> abs2(x * real(x)), 4+5im) == (456.0 + 160.0im,) # function not getproperty
@test gradient(x -> abs2(x * getfield(x, :re)), 4+5im) == (456.0 + 160.0im,)

struct Bar{T}
a::T
Expand Down Expand Up @@ -418,6 +422,11 @@ end

@test gradient((x,y,z) -> sum((x,y,z)[1:2]), 7, 8.8, 9.9) == (1.0, 1.0, nothing)
@test gradient((x,y,z) -> sum((x,y,z)[[1,2,1]]), 1,2,3) == (2, 1, nothing)

@test gradient(xs -> sum(x -> x[2], xs), [(1,2,3), (4,5,6)]) == ([(nothing, 1.0, nothing), (nothing, 1.0, nothing)],)
@test gradient(xs -> sum(x -> prod(x[2:3]), xs), [(1,2,3), (4,5,6)]) == ([(nothing, 3.0, 2.0), (nothing, 6.0, 5.0)],)
@test gradient(xs -> sum(first, xs), fill((4,3),2)) == ([(1.0, nothing), (1.0, nothing)],)
@test gradient(xs -> sum(x -> abs2(x[1]), xs), fill((4,3),2)) == ([(8.0, nothing), (8.0, nothing)],)
end

@testset "@timed" begin
Expand Down Expand Up @@ -452,6 +461,13 @@ end
@test gradient(x -> x.x^2 + x.x, Ref(3)) === ((x = 7.0,),)
@test gradient(x -> real(x.x^2 + im * x.x), Ref(4)) === ((x = 8.0,),)

# Field access of contents:
@test gradient(x -> abs2(x.x) + 7 * x.x.re, Ref(1+im)) == ((x = 9.0 + 2.0im,),)
@test_broken gradient(x -> abs2(x[1].x) + 7 * x[1].x.re, [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],)
@test_broken gradient(x -> abs2(x[1].x) + 7 * real(x[1].x), [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],) # worked on 0.6.0, 0.6.20

@test_broken gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) == ((x = 9.0 + 2.0im,),) # gives nothing, same in 0.6.0

# Array of mutables:
@test gradient(x -> sum(getindex.(x).^2), Ref.(1:3))[1] == [(;x=2i) for i in 1:3]
@test gradient(x -> sum(abs2getindex, x), Ref.(1:3))[1] == [(;x=2i) for i in 1:3]
Expand All @@ -464,6 +480,17 @@ end
@test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],)
end

@testset "NamedTuples" begin
@test gradient(x -> x.a, (a=1, b=2)) == ((a = 1, b = nothing),)
@test gradient(x -> x[1].a, [(a=1, b=2)]) == ([(a = 1, b = nothing)],)
@test gradient(x -> x[1].a, [(a=1, b=2), (a=3, b=4)]) == ([(a = 1, b = nothing), nothing],)

# Mix with Ref
@test gradient(x -> x[].a, Ref((a=1, b=2))) == ((x = (a = 1, b = nothing),),)
@test gradient(x -> x[1][].a, [Ref((a=1, b=2)), Ref((a=3, b=4))]) == ([(x = (a = 1, b = nothing),), nothing],)
@test gradient(x -> x[1].a, [(a=1, b=2), "three"]) == ([(a = 1, b = nothing), nothing],)
end

function type_test()
Complex{<:Real}
end
Expand Down Expand Up @@ -692,4 +719,10 @@ end
@test gradient(x -> sum(gradient(y -> sum(y.^2), x)[1]), [1, 2])[1] [2, 2]
@test gradient(x -> sum(gradient(y -> sum(sin.(y)), x)[1]), [1, 2])[1] [-0.8414709848078965, -0.9092974268256817]
@test gradient(x -> sum(abs, gradient(y -> sum(log.(2 .* exp.(y)) .^ 2), x)[1]), [1, 2])[1] [2,2]

# getproperty, Tangents, etc
@test gradient(xs -> sum((x->x.im^2).(xs)), [1+2im,3])[1] == [4im, 0]
@test gradient(xs -> sum((x->x.im^2), xs), [1+2im,3])[1] == [4im, 0]
@test gradient(xs -> sum(map(x->x.im^2, xs)), [1+2im,3])[1] == [4im, 0]
@test gradient(xs -> mapreduce(x->x.im^2, +, xs), [1+2im,3])[1] == [4im, 0]
end
12 changes: 12 additions & 0 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1878,6 +1878,18 @@ end
a = rand(3)
@test Zygote.gradient(x->sum(x .+ rand.()), a) == (ones(3),)

@testset "Zygote 660" begin
# https://github.com/FluxML/Zygote.jl/pull/660
function example(x,N)
ax = axes(x)
extraAxe = ax[2+N:end]
filledLoc = fill(1, N)
return x[:, filledLoc..., extraAxe...]
end
y, back = pullback(example, randn(5,3,4,3), 2)
@test back(zero(y).=1) isa Tuple{Array{Float64,4}, Nothing}
end

@testset "CRC issue 440" begin
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/440
f(x,y) = sum(sum, [[x[i],y[i]] for i=1:length(x)])
Expand Down
Loading

2 comments on commit 4ed3a86

@mcabbott
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/48364

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.30 -m "<description of version>" 4ed3a86db708a27bfe0afd5aeaa6408dd8d43a3e
git push origin v0.6.30

Please sign in to comment.