diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
index 9866591e8d..85950347ba 100644
--- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
+++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -3,9 +3,10 @@
 # See LICENSE for license information.
 
 import os
-import pytest
 import subprocess
-from test_fused_attn import ModelConfig
+
+import pytest
+import torch
 from transformer_engine.pytorch.attention import (
     _flash_attn_2_plus,
     _flash_attn_2_3_plus,
@@ -15,6 +16,8 @@
     get_cudnn_version,
 )
 
+from test_fused_attn import ModelConfig
+
 model_configs_flash_attn = {
     # test: b, h, hg, d, sq, skv, p, mask, bias
     "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"),  # MHA
@@ -58,6 +61,10 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
 @pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
 @pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
 def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
+    num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
+    if num_gpus > torch.cuda.device_count():
+        pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
+
     config = model_configs_flash_attn[model]
     if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
         pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
@@ -77,7 +84,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
 
     subprocess.run(
         get_bash_arguments(
-            num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2,
+            num_gpus_per_node=num_gpus,
             dtype=dtype,
             model=model,
             qkv_format=qkv_format,
@@ -115,6 +122,10 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
 @pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
 @pytest.mark.parametrize("fp8_mha", [False, True])
 def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha):
+    num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
+    if num_gpus > torch.cuda.device_count():
+        pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
+
     if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
         pytest.skip("THD format is only supported on sm90+!")
     if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
@@ -155,7 +166,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
 
     subprocess.run(
         get_bash_arguments(
-            num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2,
+            num_gpus_per_node=num_gpus,
             dtype=dtype,
             model=model,
             qkv_format=qkv_format,