diff --git a/axlearn/common/flash_attention/gpu_attention_test.py b/axlearn/common/flash_attention/gpu_attention_test.py
index 181a68883..4f8daed9a 100644
--- a/axlearn/common/flash_attention/gpu_attention_test.py
+++ b/axlearn/common/flash_attention/gpu_attention_test.py
@@ -10,6 +10,7 @@
 
 Currently tested on A100/H100.
 """
+
 import functools
 from typing import Literal
 
@@ -28,9 +29,6 @@
 from axlearn.common.flash_attention.utils import _repeat_kv_heads, mha_reference
 from axlearn.common.test_utils import TestCase
 
-if jax.default_backend() != "gpu":
-    pytest.skip(reason="Incompatible hardware", allow_module_level=True)
-
 
 @pytest.mark.parametrize(
     "batch_size,seq_len,num_heads,per_head_dim",
@@ -51,6 +49,7 @@
 @pytest.mark.parametrize("attention_bias_type", [None, "2d", "4d"])
 @pytest.mark.parametrize("use_segment_ids", [True, False])
 @pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.float32])
+@pytest.mark.gpu
 def test_triton_fwd_only_against_ref(
     batch_size: int,
     seq_len: int,
@@ -119,6 +118,7 @@ def test_triton_fwd_only_against_ref(
     chex.assert_trees_all_close(o, o_ref, atol=0.03)
 
 
+@pytest.mark.gpu
 class FlashDecodingTest(TestCase):
     """Tests FlashDecoding."""
 
@@ -222,6 +222,7 @@ def test_decode_against_ref(
 @pytest.mark.parametrize("block_size", [64, 128])
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.float32])
+@pytest.mark.gpu
 def test_triton_against_xla_ref(
     batch_size: int,
     num_heads: int,
@@ -338,6 +339,7 @@ def ref_fn(q, k, v, bias, segment_ids, k5):
 )
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16])
+@pytest.mark.gpu
 def test_cudnn_against_triton_ref(
     batch_size: int,
     num_heads: int,
@@ -399,6 +401,7 @@ def ref_fn(q, k, v):
 @pytest.mark.parametrize("causal", [True, False])
 @pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16])
 @pytest.mark.parametrize("dropout_rate", [0.1, 0.25])
+@pytest.mark.gpu
 def test_cudnn_dropout_against_xla_dropout(
     batch_size: int,
     num_heads: int,
@@ -479,6 +482,7 @@ def ref_fn(q, k, v):
         raise ValueError(f"Unsupported dtype: {dtype}")
 
 
+@pytest.mark.gpu
 def test_cudnn_dropout_determinism():
     """Tests that cuDNN dropout produces identical outputs across runs."""
     k1, k2, k3 = jax.random.split(jax.random.PRNGKey(3), 3)
diff --git a/axlearn/common/flash_attention/tpu_attention_test.py b/axlearn/common/flash_attention/tpu_attention_test.py
index f9a99c310..ff4ae6b7f 100644
--- a/axlearn/common/flash_attention/tpu_attention_test.py
+++ b/axlearn/common/flash_attention/tpu_attention_test.py
@@ -1,6 +1,7 @@
 # Copyright © 2023 Apple Inc.
 
 """Tests TPU FlashAttention kernels."""
+
 from __future__ import annotations
 
 import unittest
@@ -29,10 +30,6 @@
 from axlearn.common.test_utils import TestCase, is_supported_mesh_shape
 from axlearn.common.utils import Tensor
 
-# Comment out to test on CPU manually. Technically, this test runs on the CPU, albeit very slowly.
-if jax.default_backend() != "tpu":
-    pytest.skip(reason="Incompatible hardware", allow_module_level=True)
-
 
 def setUpModule():
     # If on CPU, emulate 4 devices.
@@ -51,6 +48,7 @@ def jax_fn_mask(query_position: Tensor, key_position: Tensor) -> Tensor:
     return jnp.greater_equal(query_position, key_position)
 
 
+@pytest.mark.tpu
 class TestFlashAttention(TestCase):
     """Tests FlashAttention layer."""
 
diff --git a/run_tests.sh b/run_tests.sh
index b1f4ca104..e2ca20625 100755
--- a/run_tests.sh
+++ b/run_tests.sh
@@ -52,7 +52,7 @@ fi
 
 UNQUOTED_PYTEST_FILES=$(echo $1 | tr -d "'")
 pytest --durations=100 -v -n auto \
-    -m "not (gs_login or tpu or high_cpu or fp64)" ${UNQUOTED_PYTEST_FILES} \
+    -m "not (gs_login or tpu or gpu or high_cpu or fp64)" ${UNQUOTED_PYTEST_FILES} \
     --dist worksteal &
 TEST_PIDS[$!]=1
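
With the module-level hardware skips removed, selection of these suites is driven entirely by pytest markers. A minimal invocation sketch, assuming the gpu and tpu markers are registered in the project's pytest configuration (marker registration is not shown in this diff, and these commands are illustrative, not part of run_tests.sh):

    # On a GPU host: run only the gpu-marked flash-attention tests.
    pytest -m gpu axlearn/common/flash_attention/gpu_attention_test.py

    # On a TPU host: run only the tpu-marked flash-attention tests.
    pytest -m tpu axlearn/common/flash_attention/tpu_attention_test.py

The default run_tests.sh invocation continues to deselect these markers (now including gpu) so the CPU-only test run skips them.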