encountered "Batching rule for 'triton_kernel_call' not implemented" #327
I don't think jax-triton supports any of the usual JAX transformations (notably …)
I am also curious about how to customize the batching rules for a Triton kernel. It appears that the `triton_kernel_call` primitive is shared by all Triton kernels. Could batching rules conflict when multiple Triton kernels are implemented?
You definitely don't want to register an explicit batching rule for …

```python
import jax
import jax_triton as jt

@jax.custom_batching.custom_vmap
def my_function(x, y):
  ...
  return jt.triton_call(...)

@my_function.def_vmap
def my_function_vmap(axis_size, batched_in, x, y):
  # Must return (outputs, out_batched) for this function only.
  ...
```

Then, only your instance of …
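To make the `custom_vmap` pattern concrete, here is a minimal runnable sketch. It assumes a hypothetical function `scaled_add` in place of a real `jt.triton_call` wrapper, computed with a plain `jnp` expression so it runs without Triton or a GPU. The rule passed to `def_vmap` receives `axis_size`, one boolean per argument saying whether that argument is batched, and the arguments with any batch axis leading. It must return the output together with a flag saying whether the output is batched. Because the rule is attached to this one function rather than to the shared `triton_kernel_call` primitive, a second kernel wrapped the same way would carry its own, independent rule.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a function that would wrap jt.triton_call.
# A plain jnp expression is used so the sketch runs without Triton or a GPU.
@jax.custom_batching.custom_vmap
def scaled_add(x, y):
  return 2.0 * x + y


@scaled_add.def_vmap
def scaled_add_vmap(axis_size, in_batched, x, y):
  # in_batched holds one boolean per positional argument.
  x_batched, y_batched = in_batched
  # Broadcast any unbatched argument so both carry the leading batch axis.
  if not x_batched:
    x = jnp.broadcast_to(x, (axis_size,) + x.shape)
  if not y_batched:
    y = jnp.broadcast_to(y, (axis_size,) + y.shape)
  # A Triton-backed rule would launch a single kernel over the whole batch here.
  out = 2.0 * x + y
  # Return the output plus a flag saying it is batched along axis 0.
  return out, True


xs = jnp.arange(6.0).reshape(3, 2)
y = jnp.ones(2)
out = jax.vmap(scaled_add, in_axes=(0, None))(xs, y)
```

In a real kernel, the rule is where you would express the batch to Triton, for example by adding a batch dimension to the launch grid; that part is deliberately left out of this sketch.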
…mations. #327 notes that the current error messages expose unnecessary JAX and jax_triton implementation details when jax.vmap or jax.grad are used with triton_call. This change improves those error messages to clearly identify the frontend API, and provide alternative suggestions. PiperOrigin-RevId: 718881931
Got it, thank you!
…mations. #327 notes that the current error messages expose unnecessary JAX and jax_triton implementation details when jax.vmap or jax.grad are used with triton_call. This change improves those error messages to clearly identify the frontend API, and provide alternative suggestions. PiperOrigin-RevId: 718988408
Hello, I am running into a problem implementing Triton flash attention with jax-triton. I get the error message "Batching rule for 'triton_kernel_call' not implemented". Can anyone help? Thank you!
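If writing a batching rule is not practical, one possible workaround, sketched here under assumptions, is to loop over the batch with `jax.lax.map` instead of `jax.vmap`. `lax.map` is built on `lax.scan`, so the wrapped function is traced on single unbatched slices and needs no batching rule, as long as the result is not itself wrapped in `jax.vmap`. The trade-off is that the slices run sequentially rather than in one launch. `per_example_kernel` is a hypothetical stand-in for a function that calls `jt.triton_call` on one example.

```python
import jax
import jax.numpy as jnp


def per_example_kernel(q, k):
  # Hypothetical stand-in for a function that calls jt.triton_call on one
  # unbatched example; a jnp expression keeps the sketch runnable anywhere.
  return q @ k.T


@jax.jit
def batched_kernel(qs, ks):
  # lax.map applies the function to one leading-axis slice at a time, so the
  # wrapped call only ever sees unbatched inputs.
  return jax.lax.map(lambda args: per_example_kernel(*args), (qs, ks))


qs = jnp.ones((4, 8, 16))
ks = jnp.ones((4, 8, 16))
out = batched_kernel(qs, ks)
```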