
Commit

save
micmelesse committed Feb 6, 2025
1 parent 1009f29 commit b725cdc
Showing 1 changed file with 32 additions and 36 deletions.
68 changes: 32 additions & 36 deletions flash_attn/flash_attn_triton_amd/test.py
@@ -744,43 +744,43 @@ def test_op_fwd_decode_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16):
 @pytest.mark.parametrize(
     "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD",
     [
-        # (1, 1, 1, 1, 1, 1),
-        # (1, 1, 1, 2, 2, 16),
+        (1, 1, 1, 1, 1, 1),
+        (1, 1, 1, 2, 2, 16),
         (1, 1, 1, 2, 4, 16),
-        # (1, 2, 2, 2, 4, 16),
-        # (1, 4, 1, 2, 4, 16),
-        # (1, 4, 2, 2, 4, 16),
-        # (1, 1, 1, 4, 2, 16),
-        # (1, 1, 1, 4, 4, 16),
-        # (1, 2, 2, 4, 4, 16),
-        # (2, 1, 1, 4, 4, 16),
-        # (2, 2, 2, 4, 4, 16),
-        # (1, 1, 1, 128, 64, 16),
-        # (2, 2, 2, 2, 128, 1),
-        # (2, 3, 3, 2, 128, 16),
-        # (3, 2, 2, 256, 512, 16),
-        # (3, 3, 3, 128, 128, 64),
-        # (2, 4, 4, 1024, 1024, 64),
-        # (4, 6, 6, 108, 256, 224),
-        # (4, 8, 8, 2048, 2048, 128),
-        # (4, 16, 16, 4096, 4096, 64),
-        # (2, 4, 4, 8192, 8192, 32),
-        # # fa configs
-        # (4, 6, 1, 113, 203, 256),
-        # (4, 6, 1, 128, 217, 256),
-        # (4, 6, 2, 113, 211, 128),
-        # (4, 6, 2, 108, 256, 128),
-        # (4, 6, 1, 256, 512, 64),
-        # (4, 6, 1, 512, 256, 64),
-        # (4, 6, 2, 1024, 1024, 32),
-        # (4, 6, 2, 1023, 1024, 32),
-        # (4, 6, 6, 1024, 1023, 32),
-        # (4, 6, 6, 2048, 2048, 32),
+        (1, 2, 2, 2, 4, 16),
+        (1, 4, 1, 2, 4, 16),
+        (1, 4, 2, 2, 4, 16),
+        (1, 1, 1, 4, 2, 16),
+        (1, 1, 1, 4, 4, 16),
+        (1, 2, 2, 4, 4, 16),
+        (2, 1, 1, 4, 4, 16),
+        (2, 2, 2, 4, 4, 16),
+        (1, 1, 1, 128, 64, 16),
+        (2, 2, 2, 2, 128, 1),
+        (2, 3, 3, 2, 128, 16),
+        (3, 2, 2, 256, 512, 16),
+        (3, 3, 3, 128, 128, 64),
+        (2, 4, 4, 1024, 1024, 64),
+        (4, 6, 6, 108, 256, 224),
+        (4, 8, 8, 2048, 2048, 128),
+        (4, 16, 16, 4096, 4096, 64),
+        (2, 4, 4, 8192, 8192, 32),
+        # fa configs
+        (4, 6, 1, 113, 203, 256),
+        (4, 6, 1, 128, 217, 256),
+        (4, 6, 2, 113, 211, 128),
+        (4, 6, 2, 108, 256, 128),
+        (4, 6, 1, 256, 512, 64),
+        (4, 6, 1, 512, 256, 64),
+        (4, 6, 2, 1024, 1024, 32),
+        (4, 6, 2, 1023, 1024, 32),
+        (4, 6, 6, 1024, 1023, 32),
+        (4, 6, 6, 2048, 2048, 32),
     ],
 )
 @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize('dropout_p', [0.0])
-@pytest.mark.parametrize('DEBUG_INPUT', [False])
+@pytest.mark.parametrize('DEBUG_INPUT', [True])
 @pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device")
 def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, DEBUG_INPUT):
     device = "cuda"
@@ -879,10 +879,6 @@ def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p,
     print("S_dmask_fp16:", S_dmask_fp16, S_dmask_fp16.shape if S_dmask_fp16 is not None else None )
     print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape if S_dmask_fp16 is not None else None)
     torch.testing.assert_close(S_dmask_fp16.to(torch.float32) if S_dmask_fp16 is not None else None, S_dmask_fp8.to(torch.float32) if S_dmask_fp8 is not None else None, atol=ATOL_fp8, rtol=RTOL_fp8)
-
-    if HQ // HK != 1:
-        print("Skipping backward for MQA/GQA cases because atomic_add doesnot support fp8")
-        return
 
     # fp8 backward pass
     dq_fp8, dk_fp8, dv_fp8 = torch.autograd.grad(out_fp8, (q_fp8, k_fp8, v_fp8), do_fp8)
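Removing the early return above means the fp8 backward check now also runs for MQA/GQA configurations (HQ != HK). For context, a rough sketch of the comparison pattern such a test typically follows, using torch.autograd.grad to extract gradients and torch.testing.assert_close to compare them against an fp16 reference; the fp16 variable names and tolerances below are illustrative assumptions, not the repo's code:

import torch

def check_fp8_grads(out_fp16, inputs_fp16, do_fp16,
                    out_fp8, inputs_fp8, do_fp8,
                    atol=2.5e-1, rtol=2.5e-1):
    # Reference gradients from the fp16 run.
    dq_16, dk_16, dv_16 = torch.autograd.grad(out_fp16, inputs_fp16, do_fp16)
    # Gradients from the fp8 run (the call this commit now reaches for HQ != HK).
    dq_8, dk_8, dv_8 = torch.autograd.grad(out_fp8, inputs_fp8, do_fp8)
    # Compare in float32 to avoid dtype-mismatch issues in assert_close.
    for ref, got in zip((dq_16, dk_16, dv_16), (dq_8, dk_8, dv_8)):
        torch.testing.assert_close(got.to(torch.float32), ref.to(torch.float32),
                                   atol=atol, rtol=rtol)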
