[PyTorch] Skip context parallelism tests if not enough GPUs (#1508)
* Skip context parallelism tests if not enough GPUs

Signed-off-by: Tim Moon <[email protected]>

* Apply suggestions from code review

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
timmoon10 authored Feb 26, 2025
1 parent 5d85857 commit 2834e4a
Showing 1 changed file with 15 additions and 4 deletions.
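
For context before the diff: the change amounts to a small pytest guard that compares the number of GPUs a test needs against torch.cuda.device_count() and skips the test when the machine is too small, instead of letting the multi-GPU launch fail. A minimal standalone sketch of that pattern follows; the test name and body are illustrative placeholders, not the repository's code:

    import pytest
    import torch

    @pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
    def test_cp_skip_sketch(cp_comm_type):
        # Hierarchical "a2a+p2p" context parallelism uses two levels of
        # parallelism, hence 4 GPUs; the other communication types run on 2.
        num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
        if num_gpus > torch.cuda.device_count():
            pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")
        # ...launch the actual multi-GPU run here (the real tests spawn it via subprocess)...
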
19 changes: 15 additions & 4 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -3,9 +3,10 @@
# See LICENSE for license information.

import os
import pytest
import subprocess
from test_fused_attn import ModelConfig

import pytest
import torch
from transformer_engine.pytorch.attention import (
_flash_attn_2_plus,
_flash_attn_2_3_plus,
@@ -15,6 +16,8 @@
get_cudnn_version,
)

from test_fused_attn import ModelConfig

model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
@@ -58,6 +61,10 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
if num_gpus > torch.cuda.device_count():
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")

config = model_configs_flash_attn[model]
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
@@ -77,7 +84,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):

subprocess.run(
get_bash_arguments(
num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2,
num_gpus_per_node=num_gpus,
dtype=dtype,
model=model,
qkv_format=qkv_format,
@@ -115,6 +122,10 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
@pytest.mark.parametrize("fp8_mha", [False, True])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha):
num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2
if num_gpus > torch.cuda.device_count():
pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}")

if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+!")
if cp_comm_type == "all_gather" and get_cudnn_version() < (9, 3, 0):
@@ -155,7 +166,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha):

subprocess.run(
get_bash_arguments(
num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2,
num_gpus_per_node=num_gpus,
dtype=dtype,
model=model,
qkv_format=qkv_format,
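Usage note: after this change, running pytest on tests/pytorch/fused_attn/test_fused_attn_with_cp.py on a machine with fewer than 2 GPUs (or fewer than 4 for the "a2a+p2p" cases) reports the affected parametrizations as skipped rather than failed.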
