[JAX] THD ring attention #1454
Conversation
/te-ci jax L1
/te-ci jax L1
Force-pushed from d486032 to 33ac4d2
/te-ci jax L1
Force-pushed from 33ac4d2 to ddd6a8b
/te-ci jax L1
Force-pushed from 6dd5fdb to 4c17948
/te-ci jax L1
transformer_engine/jax/attention.py (Outdated)
```python
if strategy == ReorderStrategy.DualChunkSwap:
    return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True)
if strategy == ReorderStrategy.Striped:
    return _inverse_reorder_causal_striped(tensor, cp_size, seq_dim)
```
Hi,
Why is reorder_causal_load_balancing() implemented in jax/cpp_extensions/attention.py, but _inverse_reorder_causal_striped in attention.py? I think _reorder_causal_striped should have the same API as tex.reorder_causal_load_balancing(), which accepts a boolean if_inverse and can handle both cases.
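For concreteness, a minimal sketch of what such a unified API could look like. This is a hedged illustration of striped reordering under the assumptions above, not the PR's actual code; the real permutation may differ:

```python
import jax.numpy as jnp

def _reorder_causal_striped(tensor, cp_size, seq_dim, inverse=False):
    """Hedged sketch of the suggested unified API (illustrative only).

    Striped load balancing: token i of the original sequence is grouped
    with every cp_size-th token, so each context-parallel rank receives
    a balanced slice of the causal mask. inverse=True undoes it.
    """
    seq_len = tensor.shape[seq_dim]
    assert seq_len % cp_size == 0
    group = (cp_size, seq_len // cp_size) if inverse else (seq_len // cp_size, cp_size)
    new_shape = tensor.shape[:seq_dim] + group + tensor.shape[seq_dim + 1 :]
    reshaped = jnp.reshape(tensor, new_shape)
    return jnp.swapaxes(reshaped, seq_dim, seq_dim + 1).reshape(tensor.shape)
```

Applying it with inverse=False and then inverse=True round-trips to the original tensor, so a single function covers both directions.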
The last argument of reorder_causal_load_balancing is actually to_contiguous, not inverse/non-inverse. reorder_causal_load_balancing has to live under cpp_extensions/attention.py because it is also used inside cpp_extensions/attention.py itself, whereas _reorder_causal_striped is not.
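To make the distinction concrete, a hedged sketch of how the two directions pair up. The import path and the False/True semantics are inferred from the call site in the diff above, not confirmed from the PR:

```python
import jax.numpy as jnp
from transformer_engine.jax import cpp_extensions as tex  # assumed import path

x = jnp.arange(16.0).reshape(8, 2)  # toy [seq, hidden] tensor
cp_size, seq_dim = 2, 0

# The trailing boolean is to_contiguous (a layout direction), not an
# inverse flag; for DualChunkSwap the two directions undo each other.
reordered = tex.attention.reorder_causal_load_balancing(x, cp_size, seq_dim, False)
restored = tex.attention.reorder_causal_load_balancing(reordered, cp_size, seq_dim, True)
```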
But for better alignment, I can move it into cpp_extensions/attention.py.
Done
```diff
@@ -310,7 +312,7 @@ def abstract(
         rng_state_shape = (seed_aval.shape[0], checker.rng_state_size)
         rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype)

-        if config.attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS:
+        if config.attn_bias_type == AttnBiasType.NO_BIAS:
             bias_batch = bias_heads = 0
         else:
             *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
```
Hi, what does the full shape of bias_aval look like here? Is this bias for PreBias, PostBias, or both?
When the bias type is NO_BIAS, a zero-shaped bias is passed. Otherwise, the bias is intended for both PreBias and PostBias.
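For illustration, a hedged sketch of the shapes implied by the unpacking in the diff above. The concrete dimensions and the exact zero-shape placeholder layout are assumptions:

```python
import jax.numpy as jnp

b, h, sq, skv = 2, 8, 128, 128  # illustrative sizes only

# Non-NO_BIAS case: a full (*batch, heads, sq, skv) bias, matching the
# unpacking `*bias_batch_shape, bias_heads, _, _ = bias_aval.shape`.
bias = jnp.zeros((b, h, sq, skv), dtype=jnp.bfloat16)

# NO_BIAS case: a zero-shaped placeholder is passed instead, so the
# abstract rule sets bias_batch = bias_heads = 0.
no_bias = jnp.zeros((0,), dtype=jnp.bfloat16)
```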
Signed-off-by: Reese Wang <[email protected]>
Force-pushed from 4c17948 to fc2ebcb
/te-ci jax L1
Description
Support P2P context parallelism (ring attention) with the THD format. This feature is only available for self-attention + causal masking + segment_ids/pos + load balancing (reorder before the attention and inverse-reorder after the attention).
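For context, a hedged sketch of the THD-format inputs and segment metadata mentioned above. The shapes and names are illustrative assumptions, not the PR's exact API:

```python
import jax.numpy as jnp

# Two sequences of lengths 3 and 5 packed along one token axis
# (total_tokens = 8), as in the THD format:
segment_ids = jnp.array([1, 1, 1, 2, 2, 2, 2, 2])  # sequence membership per token
segment_pos = jnp.array([0, 1, 2, 0, 1, 2, 3, 4])  # position within each sequence

num_heads, head_dim = 8, 64
q = jnp.zeros((8, num_heads, head_dim))  # THD layout: [total_tokens, heads, dim]
# Ring attention reorders q/k/v for load balancing before the attn and
# applies the inverse reorder to the output after the attn.
```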
Type of change
Changes
- kv_groups in test_distributed_fused_attn
- Use AttnBiasType, AttnMaskType, QKVLayout in cpp_extensions/attention.py for maintaining the readability

Checklist: