
Don't create nested thunks when accumulating #1555

Merged 3 commits from bc/unthunk-accum into master on Jan 31, 2025

Conversation

ToucheSir (Member)

Otherwise, it's too easy to create massive types that freeze compilation and blow the stack.
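As a rough illustration (a hypothetical lazy_accum, not Zygote's actual accum code), wrapping every accumulation in a fresh thunk nests one closure type per step, so the accumulator's type grows with the number of contributions:

using ChainRulesCore: Thunk, unthunk

# Hypothetical lazy accumulation: each call captures both operands in a new
# closure, so the resulting Thunk's type parameter nests one level deeper
# per accumulation.
lazy_accum(x, y) = Thunk(() -> unthunk(x) + unthunk(y))

function nest(n)
    acc = Thunk(() -> ones(3))
    for _ in 1:n
        acc = lazy_accum(acc, Thunk(() -> ones(3)))
    end
    return acc
end

typeof(nest(3))  # already three closure types deep; at a few hundred
                 # contributions, types like this stall inference and
                 # overflow the stack of code that recurses through them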

Ideally, we'd avoid allocations by using ChainRulesCore.add!!. But:

  • The ChainRulesCore APIs (including add!!) don't support nothing as a zero tangent.
  • I'm not sure whether in-place accumulation would be safe here anyway.

Hopefully fixes FluxML/Flux.jl#2585, #966 (comment).

PR Checklist

  • Tests are added
  • Documentation, if applicable

@mcabbott (Member)
When I looked, it wasn't clear that add!! could ever be safely used. There is no requirement that rrule, or even unthunk, return a fresh array which is safe to mutate -- it could share memory with another gradient elsewhere.

This change limits how often thunks can help. If y = f(x) depends internally on z, then gradient(f, x) need not compute dz. After this change, I think f(x) = sum(z * x), which uses z once, will be fine, but f(x) = sum(z * x) / sum(z) / sum(x), which uses z twice, will un-thunk and compute dz = dz1 + dz2.
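A concrete version of those two cases (shapes picked arbitrarily for illustration):

using Zygote

z = randn(3, 4)                       # captured value; gradient(f, x) never asks for dz
f1(x) = sum(z * x)                    # z used once: its cotangent can stay thunked and be discarded
f2(x) = sum(z * x) / sum(z) / sum(x)  # z used twice: accum must form dz = dz1 + dz2, which
                                      # (with this change) unthunks both contributions eagerly
x = randn(4)
gradient(f1, x)
gradient(f2, x)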

In the linked Flux issue, for row in eachrow(symm_func_matrix) is a sum over 512 things. I don't know whether there are 512 contributions to something which could be discarded (like dz) or just to things which cannot be (like dx), but either way, 512 nested thunks is not going to work.

@ToucheSir (Member, Author)

I came up with a simpler MWE for my testing, which shows the core of the problem, though I'm sure it could be simplified further:

function test(m, x, n)
  total = 0f0
  for _ in 1:n
    total += sum(m(x))
  end
  return total
end
test (generic function with 2 methods)

gradient(Dense(10 => 3), input_matrix, 500)

The problem is that it's not hard to end up with a loop that runs a couple hundred times and generates a thunk to be accumulated somewhere. For this pattern of code, it likely makes sense to unthunk after each iteration anyhow, because you're less likely to be saving work by deferring the accumulation.
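As a toy sketch of that point (hypothetical helpers, not Zygote internals): if accumulation unthunks eagerly, the running total stays a plain array and its type never grows with the iteration count.

using ChainRulesCore: Thunk, unthunk

# Hypothetical eager accumulation: unthunk both operands before adding, so the
# accumulator is always a plain Vector{Float64}, no matter how many loop
# iterations contribute to it.
eager_accum(x, y) = unthunk(x) + unthunk(y)

acc = Thunk(() -> ones(3))
for _ in 1:500
    global acc = eager_accum(acc, Thunk(() -> ones(3)))
end
typeof(acc)  # Vector{Float64}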

That does leave cases without loops but with accumulation, like f(x) = g(x) + h(x). But I think there are still enough cases of values being used once that thunks are still worth it.

@mcabbott (Member) commented Jan 24, 2025

Yes, that's better. I think you mean gradient(test, Dense(10 => 3), rand32(10), 50), and this crashes.

This one never wanted thunks. It would be nice if the following variant (which computes lots of dW contributions) could keep the thunks?

W = randn(3, 10); x = randn(10);
gradient(randn(3)) do b  # want only bias contributions
  m = Dense(W, b)
  test(m, x, 50)  # ok, crashes at 500
end

Here's an idea I was sketching. We could have another kind of thunk to store thunks... with a dynamic length:

struct AccumThunk{T} <: AbstractThunk
  acc::Vector{T}
end

accum(x::Thunk, y::Thunk) = AccumThunk([x, y])  # keep both thunks
accum(x::AccumThunk, y::Thunk) = AccumThunk(vcat(x.acc, y))  # add this one to the list
accum(x::AccumThunk, y::AccumThunk) = AccumThunk(vcat(x.acc, y.acc))  # combine lists
accum(x::AccumThunk, y::Nothing) = x  # don't store nothings... (but also do all this without ambiguities!)

function unthunk(x::AccumThunk)
    dxs = unthunk.(x.acc)
    # Now I want to allocate just one more array for the sum...
    broadcast(accum, dxs...)  # but ideally without the splat?
end

For your example (where all gradients are kept), I think this needs half the memory in total, although unthunk.(x.acc) allocates them all at once, so perhaps the peak is higher? It might be safe to do this instead, though I'm not very sure:

function unthunk(x::AccumThunk)
    a = accum(unthunk(x.acc[1]), unthunk(x.acc[2]))  # this a is newly allocated, right? surely?
    for b in x.acc[3:end]
        a = add!!(a, b)  # this might let InplaceableThunks actually do something?
    end
    a
end

@ToucheSir (Member, Author)

One tricky thing with AccumThunk is what happens when the accumulated thunks have different types. For example, if the hypothetical f(x) = g(x) + h(x) had different thunked computations from g and h.
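Concretely (assuming something like the AccumThunk sketch above), two thunks coming from different rules have different closure types, so the stored vector can't stay concretely typed:

using ChainRulesCore: Thunk

t1 = Thunk(() -> ones(3))       # closure type from one thunked computation
t2 = Thunk(() -> 2 .* ones(3))  # a different closure type from another
acc = [t1, t2]
typeof(acc)  # Vector{Thunk}: the element type is abstract, so an acc::Vector{T}
             # field loses its concrete T once g and h contribute different thunks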

If we were able to confidently modify the AD compiler (which is unlikely), I could see a pass which only inserts unthunks in the presence of control flow. That would eliminate the loop problem, but for now we're probably limited to solutions that don't touch code generation.

@ToucheSir (Member, Author)

Can we merge this in the interim to stop real-world code from erroring? It doesn't seem like there's an easy, quick-to-implement alternative, so I'd rather leave time to develop one properly than force one through as a bugfix.

ToucheSir merged commit 12cd77d into master on Jan 31, 2025 (9 of 12 checks passed) and deleted the bc/unthunk-accum branch.