Commit
1004: ensure `sum(f,x)` works on GPU r=DhairyaLGandhi a=oxinabox

Seems like there is some pain from #990 re: GPU.
In particular we broke DiffEqSensitivity:
https://buildkite.com/julialang/diffeqsensitivity-dot-jl/builds/169#d254017e-e824-4d9c-854d-f3b348395599/411-877

@ChrisRackauckas 's "M"WE is

```
using DiffEqFlux, OrdinaryDiffEq, DiffEqSensitivity
using CUDA, Test, Zygote
CUDA.allowscalar(false)

H = CuArray(rand(Float32, 2, 2))
ann = FastChain(FastDense(1, 4, tanh))
p = initial_params(ann)

function func(x, p, t)
    ann([t],p)[1]*H*x
end

x0 = CuArray(rand(Float32, 2))
x1 = CuArray(rand(Float32, 2))

prob = ODEProblem(func, x0, (0.0f0, 1.0f0))

function evolve(p)
    solve(prob, Tsit5(), p=p, save_start=false,
          save_everystep=false, abstol=1e-4, reltol=1e-4,
          sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())).u[1]
end

function cost(p)
    x = evolve(p)
    c = sum(abs,x - x1)
    #println(c)
    c
end

grad = Zygote.gradient(cost,p)[1]
@test !iszero(grad[1])
@test iszero(grad[2:4])
@test !iszero(grad[5])
@test iszero(grad[6:end])
```

I am hoping we can get it to fail with just `sum(f, xs)` (which I have added to the tests).

I can't run GPU code locally, which makes testing this hard.
If I have to I will spin up an EC2 instance, but I would really rather not.

From looking at [the logs](https://buildkite.com/julialang/diffeqsensitivity-dot-jl/builds/169#d254017e-e824-4d9c-854d-f3b348395599/411-877), I think what is going on is that the error happens during the forward pass.
In particular, here:
https://github.com/JuliaDiff/ChainRules.jl/blob/52a0eeadf8d19bff491f224517b7b064ce1ba378/src/rulesets/Base/mapreduce.jl#L46

I think this is why Zygote implemented the pullback of `sum(f, x)` as `sum(f.(x))` (which is slower and more allocate-y than our newer version): so that it could hit the special-case code Zygote has for CUDA, which does forward-mode.
(Which means it doesn't need the Context object containing the IdDict.)

So I think the short-term solution is probably to add the old rule for `sum` back in (but for `CuArray` only) here:
https://github.com/FluxML/Zygote.jl/blob/531da8bb7753f46294bc13f9d2a2fdd54917f926/src/lib/broadcast.jl#L244

```
# Make sure sum(f, ::CuArray) uses forward mode broadcast AD defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU safe
@adjoint function sum(f, xs::CuArray; kws...)
  @assert !haskey(kws, :init) # TODO add init support (julia 1.6)
  return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
```

In the longer term, we will probably default to doing the `f` from `sum(f, xs)` in forward-mode anyway.
So Zygote's rule config can be updated to say that it uses ForwardDiff.jl for its `frule_via_ad`.

Co-authored-by: Lyndon White <[email protected]>
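As a rough CPU-only illustration of why the `sum(f.(xs))` form is friendlier to the GPU, the sketch below differentiates `sum(f, xs)` with elementwise forward mode. The `Dual` type and the `sum_gradient` helper are hypothetical names made up for this example; they are not what Zygote, ForwardDiff.jl, or ChainRules actually define. The point is only that each element carries its own derivative, so no shared mutable state (such as the `IdDict` in a context object) is needed.

```julia
# Hypothetical minimal dual number, for illustration only.
# This is NOT the type Zygote or ForwardDiff.jl uses internally.
struct Dual{T<:Real}
    val::T   # primal value
    der::T   # derivative with respect to the input element
end

# Forward-mode rules for the two scalar functions used in this sketch.
Base.abs(d::Dual) = Dual(abs(d.val), sign(d.val) * d.der)
Base.tanh(d::Dual) = (t = tanh(d.val); Dual(t, (1 - t^2) * d.der))

# Value and gradient of sum(f, xs), computed as sum(f.(xs)):
# seed every element with derivative 1, broadcast f, then read off
# the per-element derivatives. Each element is independent of the others.
function sum_gradient(f, xs::AbstractArray{T}) where {T<:Real}
    duals = f.(Dual.(xs, one(T)))
    return sum(d -> d.val, duals), map(d -> d.der, duals)
end

xs = Float32[1, -2, 3]
total, grad = sum_gradient(abs, xs)
```

Because the broadcast touches each element independently, the same pattern can in principle be compiled into a single GPU kernel, which is the behaviour the `CuArray`-specific rule above is trying to keep.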
18a6f2a
@JuliaRegistrator register
Registration pull request created: JuliaRegistries/General/39339
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via: