diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py
index df94b09e8..8255bb23b 100644
--- a/flash_attn/flash_attn_triton_amd/test.py
+++ b/flash_attn/flash_attn_triton_amd/test.py
@@ -744,43 +744,43 @@ def test_op_fwd_decode_int4_kv(B, Mq, Mkv, Hq, Hkv, K, dtype=torch.float16):
 @pytest.mark.parametrize(
     "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD",
     [
-        # (1, 1, 1, 1, 1, 1),
-        # (1, 1, 1, 2, 2, 16),
+        (1, 1, 1, 1, 1, 1),
+        (1, 1, 1, 2, 2, 16),
         (1, 1, 1, 2, 4, 16),
-        # (1, 2, 2, 2, 4, 16),
-        # (1, 4, 1, 2, 4, 16),
-        # (1, 4, 2, 2, 4, 16),
-        # (1, 1, 1, 4, 2, 16),
-        # (1, 1, 1, 4, 4, 16),
-        # (1, 2, 2, 4, 4, 16),
-        # (2, 1, 1, 4, 4, 16),
-        # (2, 2, 2, 4, 4, 16),
-        # (1, 1, 1, 128, 64, 16),
-        # (2, 2, 2, 2, 128, 1),
-        # (2, 3, 3, 2, 128, 16),
-        # (3, 2, 2, 256, 512, 16),
-        # (3, 3, 3, 128, 128, 64),
-        # (2, 4, 4, 1024, 1024, 64),
-        # (4, 6, 6, 108, 256, 224),
-        # (4, 8, 8, 2048, 2048, 128),
-        # (4, 16, 16, 4096, 4096, 64),
-        # (2, 4, 4, 8192, 8192, 32),
-        # # fa configs
-        # (4, 6, 1, 113, 203, 256),
-        # (4, 6, 1, 128, 217, 256),
-        # (4, 6, 2, 113, 211, 128),
-        # (4, 6, 2, 108, 256, 128),
-        # (4, 6, 1, 256, 512, 64),
-        # (4, 6, 1, 512, 256, 64),
-        # (4, 6, 2, 1024, 1024, 32),
-        # (4, 6, 2, 1023, 1024, 32),
-        # (4, 6, 6, 1024, 1023, 32),
-        # (4, 6, 6, 2048, 2048, 32),
+        (1, 2, 2, 2, 4, 16),
+        (1, 4, 1, 2, 4, 16),
+        (1, 4, 2, 2, 4, 16),
+        (1, 1, 1, 4, 2, 16),
+        (1, 1, 1, 4, 4, 16),
+        (1, 2, 2, 4, 4, 16),
+        (2, 1, 1, 4, 4, 16),
+        (2, 2, 2, 4, 4, 16),
+        (1, 1, 1, 128, 64, 16),
+        (2, 2, 2, 2, 128, 1),
+        (2, 3, 3, 2, 128, 16),
+        (3, 2, 2, 256, 512, 16),
+        (3, 3, 3, 128, 128, 64),
+        (2, 4, 4, 1024, 1024, 64),
+        (4, 6, 6, 108, 256, 224),
+        (4, 8, 8, 2048, 2048, 128),
+        (4, 16, 16, 4096, 4096, 64),
+        (2, 4, 4, 8192, 8192, 32),
+        # fa configs
+        (4, 6, 1, 113, 203, 256),
+        (4, 6, 1, 128, 217, 256),
+        (4, 6, 2, 113, 211, 128),
+        (4, 6, 2, 108, 256, 128),
+        (4, 6, 1, 256, 512, 64),
+        (4, 6, 1, 512, 256, 64),
+        (4, 6, 2, 1024, 1024, 32),
+        (4, 6, 2, 1023, 1024, 32),
+        (4, 6, 6, 1024, 1023, 32),
+        (4, 6, 6, 2048, 2048, 32),
     ],
 )
 @pytest.mark.parametrize('causal', [False])
 @pytest.mark.parametrize('dropout_p', [0.0])
-@pytest.mark.parametrize('DEBUG_INPUT', [False])
+@pytest.mark.parametrize('DEBUG_INPUT', [True])
 @pytest.mark.skipif(not arch_supports_fp8(), reason="fp8 not supported on this device")
 def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, DEBUG_INPUT):
     device = "cuda"
@@ -879,10 +879,6 @@ def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, DEBUG_INPUT):
         print("S_dmask_fp16:", S_dmask_fp16, S_dmask_fp16.shape if S_dmask_fp16 is not None else None )
         print("S_dmask_fp8:", S_dmask_fp8, S_dmask_fp8.shape if S_dmask_fp16 is not None else None)
         torch.testing.assert_close(S_dmask_fp16.to(torch.float32) if S_dmask_fp16 is not None else None, S_dmask_fp8.to(torch.float32) if S_dmask_fp8 is not None else None, atol=ATOL_fp8, rtol=RTOL_fp8)
-
-    if HQ // HK != 1:
-        print("Skipping backward for MQA/GQA cases because atomic_add doesnot support fp8")
-        return
 
     # fp8 backward pass
     dq_fp8, dk_fp8, dv_fp8 = torch.autograd.grad(out_fp8, (q_fp8, k_fp8, v_fp8), do_fp8)