fp8 backward #119
base: main_perf
Conversation
Force-pushed from 6b691eb to 297742b
I'm approving the PR because I can't see anything wrong with it. I just left some questions and cleanup suggestions.
@@ -553,6 +636,14 @@ def attention_prefill_backward_triton_impl(
    print("use_exp2:", use_exp2)
    print("sequence_parallel:", sequence_parallel)

    is_fp8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}
    if is_fp8:
I think this empty if is_fp8: statement can be removed. Do we need to print or debug anything inside it?
I'm marking this discussion as unresolved because I think the empty if statement is still here.
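For illustration only (not code from the PR), these are the two obvious resolutions, assuming the debug prints shown in the hunk above are the intended use of the branch:

# Option 1: keep only the flag and drop the empty branch.
is_fp8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn,
                                             torch.float8_e5m2, torch.float8_e5m2fnuz}

# Option 2 (hypothetical): if something is worth logging, put it inside the branch
# next to the other debug prints instead of leaving it empty.
if is_fp8:
    print("fp8 backward path enabled, q.dtype:", q.dtype)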
Force-pushed from 7d0277c to 4cd9e2a
Enable BWD fp8 with per block scale factors for p and ds

This is a combination of 9 commits:

Enable BWD fp8

This is a combination of 12 commits:

add backward test case
save
clean up
disable ci
lse is good
dv matches
reduce diff
use do fp8 for dv kinda working
group size is a constexpr
clean up a bit
everything except mqa/gqa works
skip mqa cases

20 cases have nan on dropout
save what you have
disable tests failing
enable tests
per block descale_p and descale_ds use max(abs())
clean up tests a bit more
fix bug
disable ci for now
pass variables
add flags
add alternate path. Still need to load descale factors
dv working
dk works
save
Force-pushed from b725cdc to e6a67b3
…th causal. Varlen has some issues. Might be related to strides.
Great job Michael! Kudos for introducing the compute_fp8_scaling_factors Triton function; it's really useful for avoiding code repetition.
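For context, a helper of this kind usually derives one scale/descale pair per block from the block's maximum magnitude. The body below is only a sketch of that general technique under this assumption, not the PR's actual implementation, and is meant to be called from inside a Triton kernel:

import triton
import triton.language as tl

@triton.jit
def compute_fp8_scaling_factors(x, FP8_MAX: tl.constexpr):
    # Sketch only: per-block max magnitude, clamped so an all-zero block
    # does not cause a division by zero.
    x_amax = tl.maximum(tl.max(tl.abs(x)), 1e-9)
    scale = FP8_MAX / x_amax    # applied before casting the block to fp8
    descale = x_amax / FP8_MAX  # applied after the fp8 matmul to undo the scaling
    return scale, descale

Under this reading, the descale_p and descale_ds factors mentioned in the commit message would be the values applied after the corresponding fp8 matmuls.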
    sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
    *bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
    dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
    HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
    MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen,
    BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
    USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p
-   > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=is_fp8)
+   > 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=is_fp8, FP8_MAX=torch.finfo(torch.float8_e4m3fnuz).max)
Suggestion: Since is_fp8 is defined as:
is_fp8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}
I think it's better to compute FP8_MAX as:
FP8_MAX=torch.finfo(q.dtype).max
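A quick, standalone check (not part of the PR) of why deriving FP8_MAX from q.dtype matters: the four FP8 formats accepted by the is_fp8 check do not share the same maximum representable value, so a hard-coded float8_e4m3fnuz max does not match the other formats.

import torch

# Maximum representable value per FP8 dtype, as reported by torch.finfo
# (values from recent PyTorch builds; verify on your own install).
for dt in (torch.float8_e4m3fnuz, torch.float8_e4m3fn,
           torch.float8_e5m2, torch.float8_e5m2fnuz):
    print(dt, torch.finfo(dt).max)
# torch.float8_e4m3fnuz 240.0
# torch.float8_e4m3fn   448.0
# torch.float8_e5m2     57344.0
# torch.float8_e5m2fnuz 57344.0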
@@ -650,56 +741,26 @@ def attention_prefill_backward_triton_impl(
    do,
    delta,
    stride_oz, stride_oh, stride_om, stride_ok,
    stride_oz, stride_oh, stride_om, stride_ok,
    stride_oz, stride_oh, stride_om, stride_ok, # FIXME: don't share strides with derivatives this was causing a lot of issues
Question: Should we merge this PR without addressing this FIXME? Is it supposed to be addressed later?
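As an aside on why sharing strides is fragile (an illustration, not code from the PR; the tensor shapes below are made up): the gradient do is not guaranteed to have the same memory layout as o, so passing o's strides for both can index do incorrectly.

import torch

# Hypothetical shapes: o is contiguous, do comes from a transpose and has
# the same shape but a different memory layout.
o = torch.empty(2, 4, 128, 64)
do = torch.empty(2, 128, 4, 64).transpose(1, 2)

stride_oz, stride_oh, stride_om, stride_ok = o.stride()
stride_doz, stride_doh, stride_dom, stride_dok = do.stride()

print(o.stride())   # (32768, 8192, 64, 1)
print(do.stride())  # (32768, 64, 256, 1) -- different layout, so o's strides must not be reused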
    IS_VARLEN=is_varlen,
    GROUP_SIZE=group_size,
    IS_FP8=is_fp8,
    FP8_MAX=torch.finfo(torch.float8_e4m3fnuz).max
Suggestion: Since is_fp8 is computed as:
is_fp8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}
I think it's better to compute FP8_MAX as:
FP8_MAX=torch.finfo(q.dtype).max
    DEBUG_TRITON: bool = False,
    DEBUG_TRITON_DETAIL: bool = False,
):
    IS_FP8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}
    if IS_FP8:
        FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max
Suggestion: Compute FP8_MAX as torch.finfo(q.dtype).max.
    (4, 8, 8, 2048, 2048, 128),
    (4, 16, 16, 4096, 4096, 64),
    (2, 4, 4, 8192, 8192, 32),
    # (1, 1, 1, 1, 1, 1),
Question: Is it our intention to leave only one test case? Or is this still a work in progress?
run: |
  export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
  pytest tests/test_flash_attn_triton_amd.py
# - name: Flash Attention Tests Using Reference Impl
Question: My knowledge of GitHub Actions is almost none, so please take this comment with a grain of salt... As far as I can see, the MI300 integration job is commented out. Am I correct? Do we really want to merge it this way?
add fp8 backward