Commit

Update limitations.md
mcabbott authored Jan 4, 2025
1 parent bb3730e commit c8db49a
Showing 1 changed file with 20 additions and 27 deletions.
47 changes: 20 additions & 27 deletions docs/src/limitations.md
@@ -20,7 +20,6 @@ Let's explore this with a more concrete example. Here we define a simple mutatin
```julia
function f!(x)
x .= 2 .* x

return x
end
```
@@ -42,43 +41,36 @@ Stacktrace:
...
```
We got an error message and a long stacktrace. The error informs us that our code performs array mutation by calling `copyto!` (we might not have directly called this function, but it is being invoked somewhere in the call stack). We see that our code includes `x .= ...` which is given as an example of array mutation. Other examples of mutating operations include:
- setting values (`x .= ...`)
- appending/popping values (`push!(x, v)` / `pop!(x)`)
- calling mutating functions (`mul!(C, A, B)`)
- setting values (`x[i] = val` or `x .= values`)
- appending/popping values (`push!(x, v)` or `pop!(x)`)
- calling mutating functions (such as `LinearAlgebra.mul!(C, A, B)`)

!!! warning

Non-mutating functions might also use mutation under the hood. This can be done for performance reasons or code re-use.

```julia
function g!(x, y)
x .= 2 .* y

function g_inner!(x, y)
for i in eachindex(x, y)
x[i] = 2 * y[i]
end
return x
end
g(y) = g!(similar(y), y)
```
Here `g` is a "non-mutating function," and it indeed does not mutate `y`, its only argument. But it still allocates a new array and calls `g!` on this array which will result in a mutating operation. You may encounter such functions when working with another package.

Specifically for array mutation, we can use [`Zygote.Buffer`](@ref) to re-write our function. For example, let's fix the function `g!` above.
```julia
function g!(x, y)
x .= 2 .* y

return x
function g_outer(y)
z = similar(y)
g_inner!(z, y)
return z
end
```
Here `g_outer` does not mutate `y`, its only argument. But it still allocates a new array `z` and calls `g_inner!` on this array, which will result in a mutating operation. You may encounter such functions when working with another package.
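
To make the distinction concrete, here is a minimal sketch in plain Julia (no AD involved). It repeats the two definitions so that it runs on its own, then checks that `g_outer` leaves its argument untouched even though it writes into a freshly allocated array internally. The input values are made up for illustration.

```julia
# Same definitions as in the text, repeated so the snippet is self-contained.
function g_inner!(x, y)
    for i in eachindex(x, y)
        x[i] = 2 * y[i]   # array mutation: the kind of operation described above
    end
    return x
end

function g_outer(y)
    z = similar(y)        # fresh, uninitialised array with y's shape and element type
    g_inner!(z, y)        # the mutation happens here, hidden from the caller
    return z
end

y = [1.0, 2.0, 3.0]       # illustrative input
y_before = copy(y)
z = g_outer(y)

println(z)                # doubled values
println(y == y_before)    # the argument itself is unchanged
println(z !== y)          # the result is a different array
```

From the caller's side `g_outer` looks pure, which is why such functions can be hard to spot when they come from another package.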

function g(y)
x = Zygote.Buffer(y) # Buffer supports syntax like similar
g!(x, y)
return copy(x) # this step makes the Buffer immutable (w/o actually copying)
end
How can you solve this problem?
* Re-write the code not to use mutation. Here we can obviously write `g_better(y) = 2 .* y` using broadcasting. Many other cases may be solved by writing comprehensions `[f(x, y) for x in xs, y in ys]` or using `map(f, xs, ys)`, instead of explicitly allocating an output array and then writing into it.
* Write a custom rule, defining `rrule(::typeof(g), y)` using what you know about `g` to derive the right expression.
* Use another AD package instead of Zygote for part of the calculation. Replacing `g(y)` with `Zygote.forwarddiff(g, y)` will compute the same value, but when it is time to find the gradient, this job is outsourced to [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). ForwardDiff has its own limitations but mutation isn't one of them.
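
As a rough illustration of the first option, the sketch below compares an allocate-then-write version with three non-mutating rewrites (broadcasting, `map`, and a comprehension). The names `g_loop`, `g_broadcast`, `g_map`, and `g_comprehension` are hypothetical and chosen only for this example. It is plain Julia and checks only that the values agree, not anything about gradients.

```julia
# Hypothetical mutating reference: allocate an output array, then write into it.
function g_loop(y)
    z = similar(y)
    for i in eachindex(z, y)
        z[i] = 2 * y[i]
    end
    return z
end

# Non-mutating rewrites that build the output directly.
g_broadcast(y) = 2 .* y
g_map(y) = map(v -> 2 * v, y)
g_comprehension(y) = [2 * v for v in y]

y = [0.5, 1.5, 2.5]   # illustrative input
reference = g_loop(y)

println(g_broadcast(y) == reference)
println(g_map(y) == reference)
println(g_comprehension(y) == reference)

# The two-argument forms mentioned above are not interchangeable:
# a two-variable comprehension covers every pair, while map walks the inputs in step.
pairs_matrix = [a + b for a in [1, 2], b in [10, 20, 30]]
zipped = map(+, [1, 2], [10, 20])
println(size(pairs_matrix))
println(zipped)
```

Which rewrite reads best depends on the shape of the computation: broadcasting suits elementwise arithmetic, while a comprehension is often clearer when the output ranges over combinations of inputs.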

julia> gradient(rand(3)) do y
sum(g(y))
end
([2.0, 2.0, 2.0],)
```
Finally, there is also [`Zygote.Buffer`](@ref) which aims to handle the pattern of allocating space and then mutating it. But it has many bugs and is not really recommended.

## Try-catch statements

@@ -136,7 +128,8 @@ For all of the errors above, the suggested solutions are similar. You have the f
2. define a [custom `ChainRulesCore.rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html)
3. open an [issue on Zygote](https://github.com/FluxML/Zygote.jl/issues)

Avoiding the operation is simple, just don't do it! If you are using a mutating function, try to use a non-mutating variant. If you are using `try`/`catch` statements, try to use more graceful error handling such as returning `nothing` or another sentinel value. Recall that array mutation can also be avoided by using [`Zygote.Buffer`](@ref) as discussed above.
Avoiding the operation is simple, just don't do it! If you are using a mutating function, try to use a non-mutating variant. Instead of allocating an array and writing into it, try to make the output directly using broadcasting, `map`, or a comprehension.
If you are using `try`/`catch` statements, try to use more graceful error handling such as returning `nothing` or another sentinel value.

Sometimes, we cannot avoid expressions that Zygote cannot differentiate, but we may be able to manually derive a gradient. In these cases, you can write [a custom `rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html) using ChainRules.jl. Please refer to the linked ChainRules documentation for how to do this. _This solution is the only solution available for foreign call expressions._ Below, we provide a custom `rrule` for `jclock`.
```julia