Merge #1004
1004: ensure `sum(f,x)` works on GPU r=DhairyaLGandhi a=oxinabox

It seems there is some pain from #990 with respect to GPU support.
In particular, we broke DiffEqSensitivity:
https://buildkite.com/julialang/diffeqsensitivity-dot-jl/builds/169#d254017e-e824-4d9c-854d-f3b348395599/411-877

@ChrisRackauckas's "M"WE is:
```
using DiffEqFlux, OrdinaryDiffEq, DiffEqSensitivity
using CUDA, Test, Zygote
CUDA.allowscalar(false)
H = CuArray(rand(Float32, 2, 2))
ann = FastChain(FastDense(1, 4, tanh))
p = initial_params(ann)
function func(x, p, t)
    ann([t],p)[1]*H*x
end
x0 = CuArray(rand(Float32, 2))
x1 = CuArray(rand(Float32, 2))
prob = ODEProblem(func, x0, (0.0f0, 1.0f0))
function evolve(p)
    solve(prob, Tsit5(), p=p, save_start=false,
          save_everystep=false, abstol=1e-4, reltol=1e-4,
          sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())).u[1]
end
function cost(p)
    x = evolve(p)
    c = sum(abs, x - x1)
    #println(c)
    c
end
grad = Zygote.gradient(cost, p)[1]
@test !iszero(grad[1])
@test iszero(grad[2:4])
@test !iszero(grad[5])
@test iszero(grad[6:end])
```
I am hoping we can get it to fail with just `sum(f, xs)`, which I have added to the tests.
I can't run a GPU locally, which makes testing this hard.
If I have to I will spin up an EC2 instance, but I would really rather not.
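For reference, the reduced case I am hoping triggers it would be just this (a sketch, assuming a CUDA-capable machine, which I can't verify locally):

```
using CUDA, Zygote
CUDA.allowscalar(false)  # turn any scalar-indexing fallback into an error

xs = cu(Float32[-1.5, -9.0, 2.4])
# Before this fix, this hit the ChainRules rrule for sum(f, x), whose
# pullback closes over the Zygote Context (an IdDict) and is not GPU safe.
Zygote.gradient(x -> sum(abs, x), xs)[1]
```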

From looking at [the logs](https://buildkite.com/julialang/diffeqsensitivity-dot-jl/builds/169#d254017e-e824-4d9c-854d-f3b348395599/411-877), here is what I think is going on.

The error happens during the forward pass, in particular here:
https://github.com/JuliaDiff/ChainRules.jl/blob/52a0eeadf8d19bff491f224517b7b064ce1ba378/src/rulesets/Base/mapreduce.jl#L46
I think this is why Zygote implemented the pullback of `sum(f, x)` as `sum(f.(x))` (which is slower and allocates more than our newer version): it lets the call hit Zygote's special CUDA code that runs in forward mode, and which therefore doesn't need the `Context` object containing the `IdDict`.
So I think the short-term solution is probably to add the old rule for `sum` back (but restricted to `CuArray`), here:
https://github.com/FluxML/Zygote.jl/blob/531da8bb7753f46294bc13f9d2a2fdd54917f926/src/lib/broadcast.jl#L244
```
# Make sure sum(f, ::CuArray) uses the forward-mode broadcast AD defined above,
# not the ChainRules.rrule, which uses the Zygote.Context and thus is not GPU safe
@adjoint function sum(f, xs::CuArray; kws...)
  @assert !haskey(kws, :init) # TODO: add init support (Julia 1.6)
  return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
```

In the longer term, we will probably default to running the `f` from `sum(f, xs)` in forward mode anyway.
Zygote's rule config can then be updated to declare that it uses ForwardDiff.jl for its `frule_via_ad`.
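To sketch the longer-term idea (my sketch, not Zygote's actual implementation): push each element of `xs` through `f` as a `ForwardDiff.Dual`, so the elementwise derivative falls out of a single forward-mode broadcast and the pullback is itself a plain broadcast, with no `Context`/`IdDict` needed.

```
using ForwardDiff: Dual, value, partials

# Forward-mode sketch of the pullback for sum(f, xs):
# seed each element with partial 1, broadcast f once, then read off
# both the primal value and the elementwise derivative df/dx.
function sum_fwd_pullback(f, xs)
    ys = f.(Dual.(xs, one.(xs)))        # one forward-mode broadcast
    y = sum(value, ys)                  # primal value of sum(f, xs)
    dfdx = map(d -> partials(d, 1), ys) # df/dx at each element
    return y, Δ -> Δ .* dfdx            # pullback is a plain broadcast
end

y, back = sum_fwd_pullback(abs, Float32[-1.5, 2.4])
back(1f0)  # ≈ [-1.0, 1.0], the sign of each element
```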

Co-authored-by: Lyndon White <[email protected]>
bors[bot] and oxinabox authored Jun 21, 2021
2 parents 9759239 + bd8c5fb commit 18a6f2a
Showing 3 changed files with 20 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
 name = "Zygote"
 uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.13"
+version = "0.6.14"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
9 changes: 8 additions & 1 deletion src/lib/broadcast.jl
@@ -254,7 +254,14 @@ end
placeholder = similar(xs)
sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
end


# Make sure sum(f, ::CuArray) uses broadcast through forward-mode defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
@adjoint function sum(f, xs::CUDA.CuArray; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end

@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.CuArray}
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end
11 changes: 11 additions & 0 deletions test/cuda.jl
@@ -26,6 +26,17 @@ end
@test g_gpu |> collect ≈ g
end

@testset "sum(f, x)" begin
a = Float32.([-1.5, -9.0, 2.4, -1.3, 0.01])
a_gpu = a |> cu

f(x) = sum(abs, x)
g = gradient(f, a)[1]
g_gpu = gradient(f, a_gpu)[1]
@test g_gpu isa CuArray
@test g_gpu |> collect ≈ g
end

@testset "jacobian" begin
v1 = cu(collect(1:3f0))


2 comments on commit 18a6f2a

@mcabbott (Member)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/39339

After the above pull request is merged, it is recommended that a tag be created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or it can be done manually through the GitHub interface, or via:

```
git tag -a v0.6.14 -m "<description of version>" 18a6f2a14ff1466a37f8143a2c7f2013a9f970ea
git push origin v0.6.14
```
