Skip to content

Commit

Permalink
Fix CUDA plugin CI. (#8593)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysiraichi authored Feb 12, 2025
1 parent 3a4ad6f commit 42edbe1
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 56 deletions.
12 changes: 4 additions & 8 deletions .github/workflows/_build_torch_with_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ jobs:
image: ${{ inputs.dev-image }}
env:
_GLIBCXX_USE_CXX11_ABI: 0
TORCH_CUDA_ARCH_LIST: "5.2;7.0;7.5;8.0;9.0"
USE_CUDA: 1
MAX_JOBS: 24
steps:
- name: Checkout actions
uses: actions/checkout@v4
Expand All @@ -34,18 +37,11 @@ jobs:
with:
torch-commit: ${{ inputs.torch-commit }}
cuda: true
- name: Checkout PyTorch Repo
uses: actions/checkout@v4
with:
repository: pytorch/pytorch
path: pytorch
ref: ${{ inputs.torch-commit }}
submodules: recursive
- name: Build PyTorch with CUDA enabled
shell: bash
run: |
cd pytorch
TORCH_CUDA_ARCH_LIST="5.2;8.6" USE_CUDA=1 MAX_JOBS="$(nproc --ignore=4)" python setup.py bdist_wheel
python setup.py bdist_wheel
- name: Upload wheel
uses: actions/upload-artifact@v4
with:
Expand Down
8 changes: 6 additions & 2 deletions .github/workflows/_test_requiring_torch_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,12 @@ jobs:
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install -U --pre jax-cuda12-pjrt jax-cuda12-plugin -f https://storage.googleapis.com/jax-releases/jax_cuda_plugin_nightly_releases.html
pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
pip install --no-deps triton==2.3.0
if: ${{ matrix.run_triton_tests }}
- name: Install Triton
shell: bash
run: |
cd pytorch
make triton
- name: Python Tests
shell: bash
run: |
Expand All @@ -106,5 +110,5 @@ jobs:
- name: Triton Tests
shell: bash
run: |
PJRT_DEVICE=CUDA TRITON_PTXAS_PATH=/usr/local/cuda-12.1/bin/ptxas python pytorch/xla/test/test_triton.py
PJRT_DEVICE=CUDA TRITON_PTXAS_PATH=/usr/local/cuda-12.3/bin/ptxas python pytorch/xla/test/test_triton.py
if: ${{ matrix.run_triton_tests }}
84 changes: 41 additions & 43 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,25 +42,25 @@ jobs:
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}

# Disable due to https://github.com/pytorch/xla/issues/8199
# build-torch-with-cuda:
# name: "Build PyTorch with CUDA"
# uses: ./.github/workflows/_build_torch_with_cuda.yml
# needs: get-torch-commit
# with:
# # note that to build a torch wheel with CUDA enabled, we do not need a GPU runner.
# dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
# torch-commit: ${{needs.get-torch-commit.outputs.torch_commit}}
# runner: linux.24xlarge
build-torch-with-cuda:
name: "Build PyTorch with CUDA"
uses: ./.github/workflows/_build_torch_with_cuda.yml
needs: get-torch-commit
with:
# TODO: bump CUDA version to either 12.4 or 12.6 (supported by PyTorch).
# Ref: https://github.com/pytorch/xla/issues/8700
dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.3
torch-commit: ${{needs.get-torch-commit.outputs.torch_commit}}
# note that to build a torch wheel with CUDA enabled, we do not need a GPU runner.
runner: linux.24xlarge

# Disable due to https://github.com/pytorch/xla/issues/8199
# build-cuda-plugin:
# name: "Build XLA CUDA plugin"
# uses: ./.github/workflows/_build_plugin.yml
# with:
# dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
# secrets:
# gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
build-cuda-plugin:
name: "Build XLA CUDA plugin"
uses: ./.github/workflows/_build_plugin.yml
with:
dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.3
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}

test-python-cpu:
name: "CPU tests"
Expand All @@ -74,32 +74,30 @@ jobs:
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}

# Disable due to https://github.com/pytorch/xla/issues/8199
# test-cuda:
# name: "GPU tests"
# uses: ./.github/workflows/_test.yml
# needs: [build-torch-xla, build-cuda-plugin, get-torch-commit]
# with:
# dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
# runner: linux.8xlarge.nvidia.gpu
# timeout-minutes: 300
# collect-coverage: false
# install-cuda-plugin: true
# torch-commit: ${{needs.get-torch-commit.outputs.torch_commit}}
# secrets:
# gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}
test-cuda:
name: "GPU tests"
uses: ./.github/workflows/_test.yml
needs: [build-torch-xla, build-cuda-plugin, get-torch-commit]
with:
dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.3
runner: linux.g4dn.12xlarge.nvidia.gpu
timeout-minutes: 300
collect-coverage: false
install-cuda-plugin: true
torch-commit: ${{needs.get-torch-commit.outputs.torch_commit}}
secrets:
gcloud-service-key: ${{ secrets.GCLOUD_SERVICE_KEY }}

# Disable due to https://github.com/pytorch/xla/issues/8199
# test-cuda-with-pytorch-cuda-enabled:
# name: "GPU tests requiring torch CUDA"
# uses: ./.github/workflows/_test_requiring_torch_cuda.yml
# needs: [build-torch-with-cuda, build-torch-xla, build-cuda-plugin, get-torch-commit]
# with:
# dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.1
# runner: linux.8xlarge.nvidia.gpu
# timeout-minutes: 300
# collect-coverage: false
# torch-commit: ${{needs.get-torch-commit.outputs.torch_commit}}
test-cuda-with-pytorch-cuda-enabled:
name: "GPU tests requiring torch CUDA"
uses: ./.github/workflows/_test_requiring_torch_cuda.yml
needs: [build-torch-with-cuda, build-torch-xla, build-cuda-plugin, get-torch-commit]
with:
dev-image: us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/development:3.10_cuda_12.3
runner: linux.8xlarge.nvidia.gpu
timeout-minutes: 300
collect-coverage: false
torch-commit: ${{needs.get-torch-commit.outputs.torch_commit}}

test-tpu:
name: "TPU tests"
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/setup/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ runs:
- name: Setup CUDA environment
shell: bash
run: |
echo "PATH=$PATH:/usr/local/cuda-12.1/bin" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.1/lib64" >> $GITHUB_ENV
echo "PATH=$PATH:/usr/local/cuda-12.3/bin" >> $GITHUB_ENV
echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-12.3/lib64" >> $GITHUB_ENV
if: ${{ inputs.cuda }}
- name: Setup gcloud
shell: bash
Expand Down
2 changes: 1 addition & 1 deletion infra/ansible/config/vars.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
cuda_repo: debian11
cuda_version: "11.8"
# Determines supported GPUs. See https://developer.nvidia.com/cuda-gpus
cuda_compute_capabilities: 7.0,7.5,8.0,9.0
cuda_compute_capabilities: 5.2,7.0,7.5,8.0,9.0
# Used for fetching clang from the right repo, see apt.yaml.
llvm_debian_repo: bullseye
clang_version: 17
Expand Down
3 changes: 3 additions & 0 deletions test/test_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch_xla.experimental.triton as xla_triton
import torch_xla
from torch_xla import runtime as xr
from torch_xla.test.test_utils import skipIfCUDA

import triton
import triton.language as tl
Expand Down Expand Up @@ -241,6 +242,8 @@ def _attn_fwd(
tl.store(O_block_ptr, acc.to(Out.type.element_ty))


# Ref: https://github.com/pytorch/xla/pull/8593
@skipIfCUDA("GPU CI is failing")
class TritonTest(unittest.TestCase):

@unittest.skipIf(xr.device_type() != 'CUDA', "This test only works on GPU.")
Expand Down
3 changes: 3 additions & 0 deletions test/torch_distributed/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.test.test_utils import skipIfCUDA

# Setup import folders.
xla_test_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
Expand Down Expand Up @@ -38,6 +39,8 @@ def _ddp_correctness(rank,
def test_ddp_correctness(self):
torch_xla.launch(self._ddp_correctness, args=(False, FLAGS.debug))

# Ref: https://github.com/pytorch/xla/pull/8593
@skipIfCUDA("GPU CI is failing")
def test_ddp_correctness_with_gradient_as_bucket_view(self):
torch_xla.launch(self._ddp_correctness, args=(False, FLAGS.debug, True))

Expand Down
5 changes: 5 additions & 0 deletions torch_xla/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
import torch_xla.utils.utils as xu


def skipIfCUDA(reason):
accelerator = xr.device_type() or ""
return lambda f: unittest.skipIf(accelerator.lower() == "cuda", reason)(f)


def mp_test(func):
"""Wraps a `unittest.TestCase` function running it within an isolated process.
Expand Down

0 comments on commit 42edbe1

Please sign in to comment.