GPUAI-1250 - Flash Attention v2.04: fix the rotary module so it can be built and used #47

Status: Open. Wants to merge 1 commit into base branch flash_attention_for_rocm.
csrc/rotary/rotary_cuda.cu → csrc/rotary/rotary_cuda.cpp (2 changes: 1 addition & 1 deletion)

@@ -4,7 +4,7 @@

 #include <torch/python.h>
 #include <ATen/native/TensorIterator.h>
-#include <ATen/native/cuda/Loops.cuh>
+#include <ATen/native/hip/Loops.cuh>

 void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2,
                        const torch::Tensor cos, const torch::Tensor sin,
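The one-line change above retargets the TensorIterator loop helpers from the CUDA header to its HIP port; together with the .cu → .cpp rename it lets the hipification step added to setup.py (below) pick the file up. As a quick sanity check, here is a minimal sketch (assuming only an installed PyTorch; not part of this PR) to confirm which backend the local torch build targets:

```python
# Minimal sketch: torch.version.hip is a version string on ROCm/HIP
# builds of PyTorch and None on CUDA builds; that distinction matches
# the cuda/Loops.cuh vs hip/Loops.cuh include chosen above.
import torch

if torch.version.hip is not None:
    print(f"ROCm/HIP build (HIP {torch.version.hip})")
else:
    print(f"CUDA build (CUDA {torch.version.cuda})")
```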
csrc/rotary/setup.py (114 changes: 88 additions & 26 deletions)

@@ -2,10 +2,12 @@
 import sys
 import warnings
 import os
+import glob
+import shutil
 from packaging.version import parse, Version

 import torch
-from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, IS_HIP_EXTENSION
 from setuptools import setup, find_packages
 import subprocess

@@ -89,34 +91,94 @@ def append_nvcc_threads(nvcc_extra_args):
 cmdclass = {}
 ext_modules = []

-raise_if_cuda_home_none("rotary_emb")
-# Check, if CUDA11 is installed for compute capability 8.0
-cc_flag = []
-_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
-if bare_metal_version < Version("11.0"):
-    raise RuntimeError("rotary_emb is only supported on CUDA 11 and above")
-cc_flag.append("-gencode")
-cc_flag.append("arch=compute_70,code=sm_70")
-cc_flag.append("-gencode")
-cc_flag.append("arch=compute_80,code=sm_80")
-if bare_metal_version >= Version("11.8"):
-    cc_flag.append("-gencode")
-    cc_flag.append("arch=compute_90,code=sm_90")
-
-ext_modules.append(
-    CUDAExtension(
-        'rotary_emb', [
-            'rotary.cpp',
-            'rotary_cuda.cu',
-        ],
-        extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops'],
-                            'nvcc': append_nvcc_threads([
-                                '-O3', '--use_fast_math', '--expt-extended-lambda'
-                            ] + cc_flag)
-                            }
-    )
-)
+def build_for_cuda():
+    raise_if_cuda_home_none("rotary_emb")
+    # Check, if CUDA11 is installed for compute capability 8.0
+    cc_flag = []
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version < Version("11.0"):
+        raise RuntimeError("rotary_emb is only supported on CUDA 11 and above")
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_70,code=sm_70")
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_80,code=sm_80")
+    if bare_metal_version >= Version("11.8"):
+        cc_flag.append("-gencode")
+        cc_flag.append("arch=compute_90,code=sm_90")
+
+    ext_modules.append(
+        CUDAExtension(
+            'rotary_emb', [
+                'rotary.cpp',
+                'rotary_cuda.cu',
+            ],
+            extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops'],
+                                'nvcc': append_nvcc_threads([
+                                    '-O3', '--use_fast_math', '--expt-extended-lambda'
+                                ] + cc_flag)
+                                }
+        )
+    )
+
+
+def rename_cpp_to_hip(cpp_files):
+    for entry in cpp_files:
+        shutil.copy(entry, os.path.splitext(entry)[0] + ".hip")
+
+
+# Defining a function to validate the GPU architectures and update them if necessary
+def validate_and_update_archs(archs):
+    # List of allowed architectures
+    allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]
+
+    # Validate if each element in archs is in allowed_archs
+    assert all(
+        arch in allowed_archs for arch in archs
+    ), f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention"
+
+
+def build_for_rocm():
+    """build for ROCm platform"""
+
+    archs = os.getenv("GPU_ARCHS", "native").split(";")
+    validate_and_update_archs(archs)
+    cc_flag = [f"--offload-arch={arch}" for arch in archs]
+
+    if int(os.environ.get("FLASH_ATTENTION_INTERNAL_USE_RTN", 0)):
+        print("RTN IS USED")
+        cc_flag.append("-DUSE_RTN_BF16_CONVERT")
+    else:
+        print("RTZ IS USED")
+
+    fa_sources = ["rotary.cpp", "rotary_cuda.cpp"] #+ glob.glob("src/*.cpp")
+
+    rename_cpp_to_hip(fa_sources)
+
+    ext_modules.append(
+        CUDAExtension(
+            'rotary_emb', [
+                'rotary.hip',
+                'rotary_cuda.hip'
+            ], #+ glob.glob("src/*.hip"),
+            extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops', "-DNDEBUG"],
+                                'nvcc': [
+                                    '-O3', "-DNDEBUG"
+                                ] + cc_flag
+                                }
+        )
+    )
+
+
+BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")
+
+if BUILD_TARGET == "auto":
+    if IS_HIP_EXTENSION:
+        build_for_rocm()
+    else:
+        build_for_cuda()
+else:
+    if BUILD_TARGET == "cuda":
+        build_for_cuda()
+    elif BUILD_TARGET == "rocm":
+        build_for_rocm()
+
 setup(
     name="rotary_emb",
     version="0.1",
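For reviewers trying the change locally, here is a hypothetical usage sketch of how the new environment variables steer the build. The concrete values below (BUILD_TARGET=rocm, GPU_ARCHS=gfx90a;gfx942) are illustrative assumptions, not defaults shipped by this PR:

```python
# Hypothetical local-build sketch (run from csrc/rotary); the env values
# are examples only. setup.py reads BUILD_TARGET and GPU_ARCHS exactly
# as in the diff above.
import os
import subprocess
import sys

env = dict(os.environ)
env["BUILD_TARGET"] = "rocm"        # "auto" (default) | "cuda" | "rocm"
env["GPU_ARCHS"] = "gfx90a;gfx942"  # semicolon-separated; must pass validate_and_update_archs()

# build_for_rocm() expands GPU_ARCHS into one HIP offload flag per arch:
print([f"--offload-arch={a}" for a in env["GPU_ARCHS"].split(";")])
# -> ['--offload-arch=gfx90a', '--offload-arch=gfx942']

# Build and install the rotary_emb extension with those flags.
subprocess.run([sys.executable, "setup.py", "install"], env=env, check=True)
```

With BUILD_TARGET left at "auto", the dispatch falls back to IS_HIP_EXTENSION, so a ROCm PyTorch install picks build_for_rocm() without any environment setup.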