Don't create nested thunks when accumulating #1555
Conversation
Otherwise, it's too easy to create massive types that freeze compilation and blow the stack.
When I looked, it wasn't clear that … This change limits how often thunks can help. If … In the linked Flux issue, …
I came up with a simpler MWE for my testing which shows the core of the problem, though I'm sure it could be simplified more:

```julia
function test(m, x, n)
    total = 0f0
    for _ in 1:n
        total += sum(m(x))
    end
    return total
end
```

```julia
gradient(test, Dense(10 => 3), input_matrix, 500)
```

The problem is that it's not hard to end up with a loop that runs a couple hundred times and generates a thunk to be accumulated somewhere. For this pattern of code, it likely makes sense to unthunk after each iteration anyhow, because you're less likely to be saving work by deferring the accumulation. That does leave cases without loops but with accumulation, like …
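To see why the nesting itself is the problem, here is a minimal standalone sketch. `MiniThunk`, `force`, and this `accum` are illustrative stand-ins for the real Zygote/ChainRulesCore machinery, not its API: each deferred accumulation wraps both arguments in a fresh closure, so the wrapper's type nests one level deeper per loop iteration, and compile time and stack depth grow with the loop count.

```julia
# Toy deferred-accumulation thunk; not Zygote's actual types.
struct MiniThunk{F}
    f::F
end

force(t::MiniThunk) = t.f()
force(x) = x

# Deferring the addition builds a new thunk wrapping both arguments,
# so the type nests one closure deeper on every call.
accum(x, y) = MiniThunk(() -> force(x) + force(y))

function build(n)
    t = 0.0
    for i in 1:n
        t = accum(t, MiniThunk(() -> Float64(i)))
    end
    return t
end

force(build(100))  # 5050.0, but typeof(build(100)) nests 100 closures deep
```

Unthunking once per iteration instead would keep the running total a plain `Float64`, at the cost of doing the additions eagerly.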
Yes, that's better. I think you mean … This one never wanted thunks. It would be nice if the following variant (which computes lots of dW contributions) could keep the thunks?

```julia
W = randn(3, 10); x = randn(10);

gradient(randn(3)) do b  # want only bias contributions
    m = Dense(W, b)
    test(m, x, 50)  # ok, crashes at 500
end
```

Here's an idea I was sketching. We could have another kind of thunk to store thunks... with a dynamic length:

```julia
struct AccumThunk{T} <: AbstractThunk
    acc::Vector{T}
end

accum(x::Thunk, y::Thunk) = AccumThunk([x, y])                        # keep both thunks
accum(x::AccumThunk, y::Thunk) = AccumThunk(vcat(x.acc, y))          # add this one to the list
accum(x::AccumThunk, y::AccumThunk) = AccumThunk(vcat(x.acc, y.acc)) # combine lists
accum(x::AccumThunk, y::Nothing) = x  # don't store nothings... (but also do all this without ambiguities!)

function unthunk(x::AccumThunk)
    dxs = unthunk.(x.acc)
    # Now I want to allocate just one more array for the sum...
    broadcast(accum, dxs...)  # but ideally without the splat?
end
```

For your example (where all gradients are kept) I think this needs half the memory in total. Although:

```julia
function unthunk(x::AccumThunk)
    a = accum(unthunk(x.acc[1]), unthunk(x.acc[2]))  # this a is newly allocated, right? surely?
    for b in x.acc[3:end]
        a = add!!(a, b)  # this might let InplaceThunks actually do something?
    end
    a
end
```
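For reference, here is a self-contained version of that flat-vector idea, with plain closures standing in for `AbstractThunk` and broadcasting `.+=` standing in for `add!!`; none of these names are ChainRulesCore API. The point is that pending thunks live in one flat `Vector`, so the wrapper's type stays fixed however many gradients pile up, and forcing it allocates only one array for the running sum.

```julia
# Standalone sketch of the flat-vector accumulation idea; illustrative only.
struct AccumList{T}
    acc::Vector{T}
end

force(f::Function) = f()  # stand-in for unthunk on a single thunk
force(x) = x

accum(x::Function, y::Function) = AccumList([x, y])      # keep both thunks
accum(x::AccumList, y::Function) = (push!(x.acc, y); x)  # append to the list
accum(x::AccumList, y::Nothing) = x                      # drop nothings

function force(x::AccumList)
    a = force(x.acc[1]) + force(x.acc[2])  # one fresh array for the running sum
    for f in x.acc[3:end]
        a .+= force(f)                     # accumulate in place, like add!!
    end
    return a
end

thunks = [let g = fill(Float64(i), 3); () -> g end for i in 1:4]
acc = accum(thunks[1], thunks[2])
foreach(f -> accum(acc, f), thunks[3:end])
force(acc)  # [10.0, 10.0, 10.0]
```

Note this sketch dodges the ambiguity question by only accepting `Function` thunks; the real thing would need methods for every thunk/tangent pairing.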
One tricky thing with AccumThunk is what happens when the accumulated thunks have different types. For example, if the hypothetical …

If we were able to confidently modify the AD compiler (i.e. unlikely), I could see a pass which only inserts unthunks in the presence of control flow. That would eliminate the loop problem, but for now we're probably limited to solutions that don't touch code generation.
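As a small concrete illustration of the mixed-type concern (hypothetical tangents, not real rules): two thunks whose results differ in type can only share a `Vector` with an abstract element type, so the unthunk loop becomes type-unstable.

```julia
# Two thunks producing different tangent types; storing them together
# forces the container's element type to widen to abstract `Function`.
dense_thunk = () -> [1.0, 2.0, 3.0]  # dense Array tangent
scalar_thunk = () -> 0.5             # scalar tangent from some other rule
acc = [dense_thunk, scalar_thunk]
eltype(acc)                          # `Function`, which is abstract
```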
Can we merge this in the interim to stop real-world code from erroring? It doesn't seem like there's an easy, quick-to-implement alternative, so I'd rather leave time to develop one than force one through as a bugfix.
Ideally, we'd avoid allocations by using `ChainRulesCore.add!!`. But functions like `add!!` don't support `nothing` as a zero tangent.

Hopefully fixes FluxML/Flux.jl#2585, #966 (comment).
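For context, the mismatch is that Zygote uses `nothing` as its zero tangent, so any accumulation helper has to special-case it. A minimal sketch of the needed methods, with `accum0` a hypothetical helper and not the real `add!!`:

```julia
# Hypothetical accumulate treating `nothing` as a strong zero, the piece
# ChainRulesCore.add!! doesn't provide for Zygote's convention.
accum0(x, y) = x + y                  # eager fallback
accum0(x, ::Nothing) = x              # nothing is a zero tangent
accum0(::Nothing, y) = y
accum0(::Nothing, ::Nothing) = nothing  # also resolves the ambiguity

accum0(1.0, nothing)  # 1.0
accum0(nothing, 2.0)  # 2.0
accum0(1.0, 2.0)      # 3.0
```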