Axis tuple is not handled properly in softmax_cross_entropy_with_integer_labels loss #1162

Open
daskol opened this issue Dec 28, 2024 · 4 comments · May be fixed by #1165

@daskol
Contributor

daskol commented Dec 28, 2024

According to the docstring of softmax_cross_entropy_with_integer_labels, axis accepts None, int, or tuple[int, ...]. However, the snippet below shows that an axis of type tuple[int, int] raises an exception.

```python
import jax.numpy as jnp
import optax

xs = jnp.ones((1, 2, 3, 4))
ys = jnp.zeros(xs.shape[:-2], dtype=jnp.int32)
mask = jnp.ones_like(xs, dtype=jnp.bool)
optax.softmax_cross_entropy_with_integer_labels(xs, ys, (-2, -1), mask)  # FAIL
# TypeError: 'tuple' object cannot be interpreted as an integer
```

The reason is that the implementation relies on take_along_axis, which accepts a scalar axis but not a tuple.

```python
logits_max = jnp.max(
    logits, axis, keepdims=True, where=where, initial=-jnp.inf
)
logits -= jax.lax.stop_gradient(logits_max)
label_logits = jnp.take_along_axis(
    logits, jnp.expand_dims(labels, axis), axis=axis
).take(0, axis=axis)
log_normalizers = jnp.log(jnp.sum(jnp.exp(logits), axis=axis, where=where))
return log_normalizers - label_logits
```

I see two options.

  1. Restrict the axis parameter of softmax_cross_entropy_with_integer_labels to either None or int.
  2. Fix the bug and implement a "vector" axis parameter.
    However, the semantics of labels then become unclear.
    Specifically, if axis is (-2, -1) as in the example above, what shape should labels have, and how does one specify the label of a 2-dimensional slice of logits? The obvious solution is to append another dimension to labels with len(axis) elements (i.e. shape (1, 2, 2) in the example above). Another solution is to treat the elements of labels as flat indices into the 2-dimensional slices of logits.
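A minimal sketch of option 2 under the flat-index convention, in plain JAX. The names `ce_integer_labels_tuple_axis` and `multi_to_flat` are hypothetical (not part of optax), and row-major (C-order) flattening of the reduced axes is an assumption. It also shows that the multi-index convention converts losslessly to flat indices, so the two candidate label formats carry the same information.

```python
import jax
import jax.numpy as jnp


def multi_to_flat(multi_labels, dims):
    """Hypothetical helper: (..., len(dims)) multi-indices -> row-major flat indices."""
    flat = jnp.zeros(multi_labels.shape[:-1], dtype=multi_labels.dtype)
    for k, size in enumerate(dims):
        # Horner-style accumulation: ((i0 * d1 + i1) * d2 + i2) ...
        flat = flat * size + multi_labels[..., k]
    return flat


def ce_integer_labels_tuple_axis(logits, labels, axis):
    """Hypothetical tuple-axis cross-entropy; labels are flat indices into the reduced axes."""
    ndim = logits.ndim
    axis = tuple(a % ndim for a in axis)  # normalize negative axes, keep the given order
    batch_axes = [a for a in range(ndim) if a not in axis]
    # Move the reduced axes to the end (in the given order) and merge them into one class axis.
    moved = jnp.transpose(logits, batch_axes + list(axis))
    batch_shape = moved.shape[: len(batch_axes)]
    flat_logits = moved.reshape(batch_shape + (-1,))
    # Ordinary single-axis integer-label cross-entropy on the merged axis.
    log_probs = jax.nn.log_softmax(flat_logits, axis=-1)
    picked = jnp.take_along_axis(log_probs, labels[..., None], axis=-1)
    return -picked[..., 0]


xs = jax.random.normal(jax.random.PRNGKey(0), (1, 2, 3, 4))
multi = jnp.array([[[2, 3], [0, 1]]])  # one (row, col) label per leading position
flat = multi_to_flat(multi, xs.shape[-2:])
loss = ce_integer_labels_tuple_axis(xs, flat, (-2, -1))  # shape (1, 2)
```

With logits of shape (1, 2, 3, 4) and axis=(-2, -1), labels under this convention have shape (1, 2) and values in [0, 12).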
@rdyro
Collaborator

rdyro commented Jan 2, 2025

Great catch! I think solution 2 can be rather ambiguous, as you point out. Given the nature of the cross-entropy loss, it's probably not too much to ask the user to reshape the label axes into a single axis.

I'm working on a fix (using your solution 1) here: #1164

@rdyro
Collaborator

rdyro commented Jan 2, 2025

I'm also leaning towards option 1 (restricting the argument to int or None) because one_hot doesn't support axis tuples, and this function simulates an explicit one_hot application.

rdyro self-assigned this Jan 2, 2025
@daskol
Contributor Author

daskol commented Jan 2, 2025

Great! Option 1 is fair, but I tried option 2 because it is needed to implement and reproduce (Shidani et al., 2024). Please take a look at #1165.

@rdyro
Collaborator

rdyro commented Jan 3, 2025

Sounds good, let's work on your PR instead of option 1.
