Don't grad through max subtract in softmax
seanmor5 committed Oct 18, 2021
1 parent 19e6ef3 commit 8916210
Showing 1 changed file with 15 additions and 1 deletion.
lib/axon/activations.ex: 16 changes (15 additions & 1 deletion)
@@ -600,7 +600,21 @@ defmodule Axon.Activations do
 end
 end)

-max_val = Nx.reduce_max(x, axes: [opts[:axis]], keep_axes: true)
+# This is a scaling term designed to prevent overflow/underflow when entries
+# of x have large magnitude. For large positive x the intermediate value e^x
+# tends towards infinity, and for large negative x it underflows to 0. Either
+# way this poisons the rest of the calculation, which would otherwise be
+# normalized by the division by sum(e^x). Subtracting the max value along the
+# reduction axis guarantees all shifted values are at most 0, so every
+# e^(x - C) lies in (0, 1]. The expression is essentially:
+#
+#     e^(x - C) / sum(e^(x - C))
+#
+# which equals e^x / sum(e^x) for any constant C. We therefore treat the max
+# value as the constant term C, so there is no need to differentiate through
+# the max. See https://github.com/google/jax/pull/2260 for a note on
+# performance.
+max_val = stop_grad(Nx.reduce_max(x, axes: [opts[:axis]], keep_axes: true))

 stable_exp =
   x
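For reference, below is a minimal, self-contained sketch of the stabilized softmax the comment describes, written against the public Nx API (Nx.reduce_max/2, Nx.exp/1, Nx.sum/2, and stop_grad/1 from Nx.Defn.Kernel). The module name, the hard-coded last axis, and the absence of option handling are illustrative simplifications, not Axon's actual implementation.

```elixir
defmodule StableSoftmaxSketch do
  import Nx.Defn

  # Numerically stable softmax along the last axis. Wrapping the max in
  # stop_grad/1 treats it as the constant C from the comment above, so
  # autodiff does not trace back through Nx.reduce_max/2.
  defn softmax(x) do
    max_val = stop_grad(Nx.reduce_max(x, axes: [-1], keep_axes: true))
    stable_exp = Nx.exp(x - max_val)
    stable_exp / Nx.sum(stable_exp, axes: [-1], keep_axes: true)
  end
end
```

Because softmax is invariant to a constant shift of its input, the gradient contribution that would otherwise flow through the max cancels out; marking it with stop_grad only prunes that branch of the backward pass, which is the performance point made in the linked JAX PR.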
