Add descriptive error messages when triton_call is used with transformations.

#327 notes that the current error messages expose unnecessary JAX and jax_triton implementation details when jax.vmap or jax.grad is used with triton_call. This change improves those error messages to clearly identify the frontend API and to suggest alternatives.

PiperOrigin-RevId: 718881931
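
As an illustration of the route the new messages point to, here is a minimal sketch of wrapping a triton_call-based kernel in jax.custom_vjp so that jax.grad works. The add_kernel, grid, and BLOCK_SIZE below are hypothetical placeholders, not part of this commit, and the sketch assumes x.size is a multiple of BLOCK_SIZE.

import jax
import jax_triton as jt
import triton
import triton.language as tl

# Hypothetical elementwise addition kernel.
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
  pid = tl.program_id(axis=0)
  offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  tl.store(out_ptr + offsets, tl.load(x_ptr + offsets) + tl.load(y_ptr + offsets))

@jax.custom_vjp
def add(x, y):
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  return jt.triton_call(x, y, kernel=add_kernel, out_shape=out_shape,
                        grid=(x.size // 32,), BLOCK_SIZE=32)

def add_fwd(x, y):
  return add(x, y), None  # addition needs no residuals

def add_bwd(_, g):
  return g, g  # d(x + y)/dx = d(x + y)/dy = 1

add.defvjp(add_fwd, add_bwd)

With the custom rule registered, jax.grad dispatches to add_bwd instead of hitting the NotImplementedError raised below.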
dfm authored and Google-ML-Automation committed Jan 23, 2025
1 parent 59f5703 commit 18b364f
Showing 2 changed files with 45 additions and 0 deletions.
jax_triton/triton_lib.py (27 additions, 0 deletions)
@@ -38,6 +38,8 @@
from jax._src.lib.mlir import ir
import jax.dlpack
import jax.extend as jex
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import xla
import jax.numpy as jnp
@@ -675,6 +677,31 @@ def prune_configs(configs, named_args, **kwargs):
platform="rocm",
)


def triton_kernel_call_raise_on_jvp(*args, **kwargs):
del args, kwargs # unused
raise NotImplementedError(
"jax_triton.triton_call does not support automatic differentiation. Use "
"jax.custom_jvp or jax.custom_vjp to implement a custom automatic "
"differentiation rule for your kernel."
)

ad.primitive_jvps[triton_kernel_call_p] = triton_kernel_call_raise_on_jvp


def triton_kernel_call_raise_on_vmap(*args, **kwargs):
del args, kwargs # unused
raise NotImplementedError(
"jax_triton.triton_call does not support batching with jax.vmap. Use "
"jax.custom_batching.custom_vmap to implement a custom batching rule for "
"your kernel."
)

batching.primitive_batchers[triton_kernel_call_p] = (
triton_kernel_call_raise_on_vmap
)


class ShapeDtype(Protocol):

@property
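For the batching side, the registration above routes jax.vmap to the raising rule, and the message points at jax.custom_batching.custom_vmap. A minimal sketch of that route, reusing the hypothetical add wrapper from the earlier example: for an elementwise kernel, one simple rule collapses the batch axis into the leading array dimension and launches the kernel once.

from jax import custom_batching

@custom_batching.custom_vmap
def add_mappable(x, y):
  return add(x, y)  # the hypothetical triton_call wrapper sketched earlier

@add_mappable.def_vmap
def add_vmap_rule(axis_size, in_batched, x, y):
  # custom_vmap presents batched arguments with the mapped axis leading.
  del axis_size  # the shapes of x and y already carry it
  assert all(in_batched), "sketch assumes both inputs are batched"
  out = add(x.reshape(-1), y.reshape(-1)).reshape(x.shape)
  return out, True  # one output, batched along its leading axis

jax.vmap(add_mappable)(xs, ys) then invokes the rule instead of raising.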
tests/triton_call_test.py (18 additions, 0 deletions)
@@ -531,6 +531,24 @@ def test_autotune_with_input_output_aliasing(self):
out = add(x, y, kernel=kernel, input_output_aliases={0: 0})
np.testing.assert_allclose(out, expected)

def test_autodiff_exception(self):
x, y = create_random_inputs([10, 100], dtype="float32")
with self.assertRaisesRegex(
NotImplementedError,
r"jax_triton.triton_call does not support automatic differentiation.*"
r"jax\.custom_jvp or jax\.custom_vjp.*",
):
jax.grad(lambda x, y: jnp.sum(add(x, y, BLOCK_SIZE=32)))(x, y)

def test_batching_exception(self):
x, y = create_random_inputs([10, 100], dtype="float32")
with self.assertRaisesRegex(
NotImplementedError,
r"jax_triton.triton_call does not support batching.*"
r"jax\.custom_batching\.custom_vmap.*",
):
jax.vmap(lambda x, y: add(x, y, BLOCK_SIZE=32))(x, y)


if __name__ == "__main__":
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
