From 39691100b7847e20fbd0361341e5619455da55d3 Mon Sep 17 00:00:00 2001
From: Boyue Li
Date: Thu, 27 Feb 2025 17:19:33 -0800
Subject: [PATCH] Enable 4- and 8-device Flash Attention layer tests.

---
 run_tests.sh | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/run_tests.sh b/run_tests.sh
index da2e7f691..07d9ab2c2 100755
--- a/run_tests.sh
+++ b/run_tests.sh
@@ -69,3 +69,12 @@ while [ ${#TEST_PIDS[@]} -ne 0 ]; do
   wait -n -p PID ${!TEST_PIDS[@]} || exit_if_error $? "Test failed."
   unset TEST_PIDS[$PID]
 done
+
+# Simulate 4- and 8-device environments to run Flash Attention layer tests.
+# Run at the end to avoid OOM.
+for num_devices in 4 8; do
+  XLA_FLAGS="--xla_force_host_platform_device_count=${num_devices}" pytest \
+    --durations=100 -v -n auto \
+    -m "not (gs_login or tpu or high_cpu or fp64)" axlearn/common/flash_attention/layer_test.py \
+    --dist worksteal
+done
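
For reference, the --xla_force_host_platform_device_count flag makes XLA's
CPU ("host") platform expose multiple virtual devices, which is how the loop
above simulates 4- and 8-device environments on a single machine. A minimal
JAX sketch of the same mechanism, assuming only that jax is installed:

    import os

    # The flag must be set before JAX initializes its backends; it tells the
    # XLA CPU ("host") platform to expose this many virtual devices.
    os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

    import jax

    # jax.devices() now reports 4 CPU devices even on a single-CPU machine,
    # so multi-device code paths (e.g. sharded flash-attention layers) can
    # be exercised in tests without real accelerators.
    print(jax.devices())
    assert jax.device_count() == 4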