fp8 backward #119 (Draft)

micmelesse wants to merge 19 commits into main_perf from micmelesse/fp8_bwd

Conversation

@micmelesse (Collaborator) commented on Jan 24, 2025

add fp8 backward

@micmelesse changed the title from add backward test case to fp8 backward on Jan 24, 2025
@micmelesse force-pushed the micmelesse/fp8_bwd branch 4 times, most recently from 6b691eb to 297742b on February 3, 2025 09:24
@micmelesse marked this pull request as ready for review on February 4, 2025 13:37

@brunomazzottiamd left a comment

I'm approving the PR because I can't see anything wrong with it. I just left some questions and cleanup suggestions.

flash_attn/flash_attn_triton_amd/README.md (resolved)
flash_attn/flash_attn_triton_amd/bwd_prefill.py (outdated, resolved)
flash_attn/flash_attn_triton_amd/bwd_prefill.py (outdated, resolved)
flash_attn/flash_attn_triton_amd/bwd_prefill.py (outdated, resolved)
@@ -553,6 +636,14 @@ def attention_prefill_backward_triton_impl(
print("use_exp2:", use_exp2)
print("sequence_parallel:", sequence_parallel)

is_fp8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}
if is_fp8:

I think this empty if is_fp8: statement can be removed. Do we need to print or debug anything inside it?

I'm marking this discussion as unresolved because I think the empty if statement is still here.
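For illustration only (not code from this PR): the thread could be resolved either by deleting the empty branch or by keeping it for fp8-specific debug output, along these lines. The print statement is a hypothetical example, not something that exists in the file.

is_fp8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}
if is_fp8:
    # hypothetical debug output; drop the branch entirely if nothing is needed here
    print("fp8 backward path:", q.dtype, "max:", torch.finfo(q.dtype).max)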

flash_attn/flash_attn_triton_amd/test.py (resolved)
flash_attn/flash_attn_triton_amd/test.py (resolved)
flash_attn/flash_attn_triton_amd/test.py (outdated, resolved)
Enable BWD fp8 with per block scale factors for p and ds

This is a combination of 9 commits.

Enable BWD fp8

This is a combination of 12 commits.

add backward test case

save clean up

disable ci

lse is good

dv matches

reduce diff

use do fp8 for dv

kinda working

group size is a constexpr

clean up a bit

everything except mqa/gqa works

skip mqa cases

20 cases have nan on dropout

save what you have

disable tests

failing

enable tests

per block descale_p and descale_ds

use max(abs(())

clean up tests a bit more

fix bug

disable ci for now

pass variables

add flags

add alternate path. Still need to load descale factors

dv working

dk works

save
@micmelesse marked this pull request as draft on February 7, 2025 19:27
@brunomazzottiamd left a comment


Great job, Michael! Kudos for introducing the compute_fp8_scaling_factors Triton function; it's really useful for avoiding code repetition.
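For readers outside the PR, here is a minimal sketch of what a per-block fp8 scaling helper could look like. The actual compute_fp8_scaling_factors in the PR may have a different signature and body, so treat this purely as an illustration of the idea.

import triton
import triton.language as tl

@triton.jit
def compute_fp8_scaling_factors(x, FP8_MAX: tl.constexpr):
    # sketch of a device-side helper: scale a block by its max |x| so the
    # values fit the fp8 range, and return the matching descale factor to
    # undo the scaling after the fp8 dot product
    x_max = tl.maximum(tl.max(tl.abs(x)), 1e-9)  # guard against all-zero blocks
    scale = FP8_MAX / x_max
    descale = x_max / FP8_MAX
    return scale, descale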

sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen,
BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=is_fp8)
> 0.0, USE_EXP2=use_exp2, RETURN_SCORES=return_softmax, IS_FP8=is_fp8, FP8_MAX=torch.finfo(torch.float8_e4m3fnuz).max)

Suggestion:

Since is_fp8 is defined as:

is_fp8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}

I think it's better to compute FP8_MAX as:

FP8_MAX=torch.finfo(q.dtype).max
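To make the suggestion concrete (illustration only, not part of the PR): the four fp8 dtypes accepted by the is_fp8 check have different representable maxima, so a hard-coded float8_e4m3fnuz maximum only matches one of them.

import torch

# print the max representable value of each fp8 dtype in the is_fp8 check;
# the values differ, which is why FP8_MAX should be derived from q.dtype
for dtype in (torch.float8_e4m3fnuz, torch.float8_e4m3fn,
              torch.float8_e5m2, torch.float8_e5m2fnuz):
    print(dtype, torch.finfo(dtype).max)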


@@ -650,56 +741,26 @@ def attention_prefill_backward_triton_impl(
do,
delta,
stride_oz, stride_oh, stride_om, stride_ok,
stride_oz, stride_oh, stride_om, stride_ok,
stride_oz, stride_oh, stride_om, stride_ok, # FIXME: don't share strides with derivatives this was causing a lot of issues

Question: Should we merge this PR without addressing this FIXME? Or is it supposed to be addressed later?
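One possible direction for the FIXME, sketched here only as an illustration (the actual kernel parameters in this PR may differ): derive the gradient tensor's strides from its own tensor instead of reusing o's strides.

# sketch: compute do's strides from do itself, so o and do are no longer
# forced to share a memory layout
stride_oz, stride_oh, stride_om, stride_ok = o.stride()
stride_doz, stride_doh, stride_dom, stride_dok = do.stride()
# the kernel would then take separate o-stride and do-stride arguments rather
# than receiving stride_oz, stride_oh, stride_om, stride_ok three times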

IS_VARLEN=is_varlen,
GROUP_SIZE=group_size,
IS_FP8=is_fp8,
FP8_MAX=torch.finfo(torch.float8_e4m3fnuz).max

Suggestion:

Since is_fp8 is computed as:

is_fp8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}

I think it's better to compute FP8_MAX as:

FP8_MAX=torch.finfo(q.dtype).max

DEBUG_TRITON: bool = False,
DEBUG_TRITON_DETAIL: bool = False,
):
IS_FP8 = arch_supports_fp8() and q.dtype in {torch.float8_e4m3fnuz, torch.float8_e4m3fn, torch.float8_e5m2, torch.float8_e5m2fnuz}
if IS_FP8:
FP8_MAX = torch.finfo(torch.float8_e4m3fnuz).max

Suggestion: Compute FP8_MAX as torch.finfo(q.dtype).max.

(4, 8, 8, 2048, 2048, 128),
(4, 16, 16, 4096, 4096, 64),
(2, 4, 4, 8192, 8192, 32),
# (1, 1, 1, 1, 1, 1),

Question: Is it our intention to leave only one test case? Or is this still a work in progress?

run: |
export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE"
pytest tests/test_flash_attn_triton_amd.py
# - name: Flash Attention Tests Using Reference Impl

Question: My knowledge of GitHub Actions is almost none, so please take this comment with a grain of salt... As far as I can see, the MI300 integration job is commented out. Am I correct? Do we really want to merge it this way?
