encountered "Batching rule for 'triton_kernel_call' not implemented" #327

Open
tlogn opened this issue Jan 22, 2025 · 4 comments

Comments

@tlogn

tlogn commented Jan 22, 2025

Hello, I am running into a problem implementing Triton flash attention with jax-triton: I get the error "Batching rule for 'triton_kernel_call' not implemented". Can anyone help? Thank you!

@dfm
Contributor

dfm commented Jan 22, 2025

I don't think jax-triton supports any of the usual JAX transformations (notably vmap in this case, but also grad, for example) out of the box. I expect that you'll need to implement custom transformation rules using jax.custom_batching.custom_vmap, and probably also jax.custom_vjp for grad support. Hope this helps!
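As a rough illustration of the custom_vjp half of this suggestion, here is a minimal sketch. It assumes one forward kernel and a separate backward kernel; `square_fwd_kernel` and `square_bwd_kernel` are hypothetical stand-ins written in jax.numpy (a real version would launch Triton kernels via `jt.triton_call`), so the sketch runs without a GPU.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for a forward and a backward Triton kernel. In practice
# each would wrap a jt.triton_call(...); plain jnp keeps this sketch runnable.
def square_fwd_kernel(x):
  return x * x

def square_bwd_kernel(x, g):
  return 2.0 * x * g

@jax.custom_vjp
def square(x):
  return square_fwd_kernel(x)

def square_fwd(x):
  # Return the primal output plus the residual the backward kernel needs.
  return square_fwd_kernel(x), x

def square_bwd(x, g):
  # x is the saved residual, g the output cotangent.
  # Return a tuple with one cotangent per primal input.
  return (square_bwd_kernel(x, g),)

square.defvjp(square_fwd, square_bwd)
```

For example, `jax.grad(lambda x: square(x).sum())(jnp.array([1.0, 2.0, 3.0]))` gives `[2., 4., 6.]`, computed by the hand-written backward kernel rather than by differentiating through the forward one.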

@tlogn
Author

tlogn commented Jan 23, 2025

I am also curious about how to customize the batching rules for the Triton kernel. It appears that the 'triton_kernel_call' function applies to all Triton kernels. Is there a possibility of conflicts in batching rules when implementing multiple Triton kernels?

@dfm
Contributor

dfm commented Jan 23, 2025

You definitely don't want to register an explicit batching rule for the triton_kernel_call_p primitive. Instead, wrap your specific kernel with custom_vmap. For example:

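```python
import jax
import jax_triton as jt

@jax.custom_batching.custom_vmap
def my_function(x, y):
  ...
  return jt.triton_call(...)

@my_function.def_vmap
def my_function_vmap(axis_size, batched_in, x, y):
  # batched_in holds one bool per input; batched inputs have the mapped axis
  # moved to position 0. The rule must return (outputs, out_batched).
  ...
```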

Then, only your instance of triton_call will have the custom batching behavior.
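To make the pattern concrete, here is a minimal sketch with two independently wrapped kernels, each with its own batching rule. `add_kernel_call` and `row_sum_kernel_call` are hypothetical stand-ins written in jax.numpy for real `jt.triton_call` launches, so it runs without a GPU. The first rule assumes an elementwise kernel that can process the whole batch in one call; the second falls back to a per-element loop. Because each rule is attached to its own custom_vmap-wrapped function, the two never interact.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for an elementwise Triton kernel (would be jt.triton_call).
def add_kernel_call(x, y):
  return x + y

@jax.custom_batching.custom_vmap
def add(x, y):
  return add_kernel_call(x, y)

@add.def_vmap
def add_vmap_rule(axis_size, in_batched, x, y):
  x_batched, y_batched = in_batched
  # Batched inputs carry the mapped axis at position 0; give unbatched
  # inputs a matching leading axis so the kernel sees equal shapes.
  if not x_batched:
    x = jnp.broadcast_to(x, (axis_size,) + x.shape)
  if not y_batched:
    y = jnp.broadcast_to(y, (axis_size,) + y.shape)
  # An elementwise kernel can handle the whole batch in a single call.
  return add_kernel_call(x, y), True

# Second hypothetical kernel with a different rule; it does not affect `add`.
def row_sum_kernel_call(x):
  return jnp.sum(x, axis=-1)

@jax.custom_batching.custom_vmap
def row_sum(x):
  return row_sum_kernel_call(x)

@row_sum.def_vmap
def row_sum_vmap_rule(axis_size, in_batched, x):
  (x_batched,) = in_batched
  if not x_batched:
    return row_sum_kernel_call(x), False
  # Fallback: apply the kernel to one batch element at a time (a sequential loop).
  return jax.lax.map(row_sum_kernel_call, x), True
```

With these definitions, `jax.vmap(add, in_axes=(0, None))(xs, y)` takes the broadcasting path, while `jax.vmap(row_sum)(xs)` takes the `lax.map` path. Folding the batch into a single launch is usually faster, but it only works when the kernel can accept the extra leading dimension.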

copybara-service bot pushed a commit that referenced this issue Jan 23, 2025
…mations.

#327 notes that the current error messages expose unnecessary JAX and jax_triton implementation details when jax.vmap or jax.grad are used with triton_call. This change improves those error messages to clearly identify the frontend API, and provide alternative suggestions.

PiperOrigin-RevId: 718881931
@tlogn
Author

tlogn commented Jan 23, 2025

Got it, thank you!

copybara-service bot pushed a commit that referenced this issue Jan 23, 2025
…mations.

#327 notes that the current error messages expose unnecessary JAX and jax_triton implementation details when jax.vmap or jax.grad are used with triton_call. This change improves those error messages to clearly identify the frontend API, and provide alternative suggestions.

PiperOrigin-RevId: 718988408