
Commit

upd
zhyncs committed Jan 30, 2025
1 parent bc9b56d · commit fc24280
Showing 1 changed file with 15 additions and 0 deletions.
sgl-kernel/setup.py: 15 additions & 0 deletions
@@ -101,6 +101,9 @@ def _get_version():
"3rdparty/tensorrt_llm/common/stringUtils.cpp",
"3rdparty/tensorrt_llm/common/tllmException.cpp",
"3rdparty/tensorrt_llm/common/cudaFp8Utils.cu",
]

sources_cutlass_moe_gemm = [
"3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp",
"3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp",
"3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu",
@@ -123,18 +126,30 @@ def _get_version():
if torch.cuda.is_available():
    if cuda_version >= (12, 0) and sm_version >= 90:
        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
        nvcc_flags.append("-DCOMPILE_HOPPER_TMA_GEMMS")
        nvcc_flags.append("-DCUTLASS_ARCH_MMA_SM90_SUPPORTED")
        sources.extend(sources_cutlass_moe_gemm)
    if sm_version >= 90:
        nvcc_flags.extend(nvcc_flags_fp8)
        nvcc_flags.append("-DENABLE_FP8")
    if sm_version >= 80:
        nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
        nvcc_flags.append("-DCUTLASS_ARCH_MMA_SM80_SUPPORTED")
        nvcc_flags.append("-DENABLE_BF16")
else:
    # compilation environment without GPU
    if enable_sm90a:
        nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a")
        nvcc_flags.append("-DCOMPILE_HOPPER_TMA_GEMMS")
        nvcc_flags.append("-DCUTLASS_ARCH_MMA_SM90_SUPPORTED")
        sources.extend(sources_cutlass_moe_gemm)
    if enable_fp8:
        nvcc_flags.extend(nvcc_flags_fp8)
        nvcc_flags.append("-DENABLE_FP8")
    if enable_bf16:
        nvcc_flags.append("-DFLASHINFER_ENABLE_BF16")
        nvcc_flags.append("-DCUTLASS_ARCH_MMA_SM80_SUPPORTED")
        nvcc_flags.append("-DENABLE_BF16")

for flag in [
    "-D__CUDA_NO_HALF_OPERATORS__",
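Side note (not part of the commit): a minimal, illustrative sketch that mirrors the SM-version gating added above, so you can preview which of the new defines would apply on a local GPU. It only assumes a CUDA-enabled PyTorch install; the variable names below are local to the sketch, not taken from setup.py.

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    sm_version = major * 10 + minor
    cuda_version = tuple(int(x) for x in torch.version.cuda.split("."))
    flags = []
    if cuda_version >= (12, 0) and sm_version >= 90:
        # Hopper path (cc 9.0+ with CUDA 12+), matching the branch of the
        # diff that also pulls in the CUTLASS MoE GEMM sources.
        flags += ["-DCOMPILE_HOPPER_TMA_GEMMS", "-DCUTLASS_ARCH_MMA_SM90_SUPPORTED"]
    if sm_version >= 90:
        flags.append("-DENABLE_FP8")
    if sm_version >= 80:
        flags += ["-DCUTLASS_ARCH_MMA_SM80_SUPPORTED", "-DENABLE_BF16"]
    print(f"sm{sm_version}, CUDA {torch.version.cuda}: {flags}")
else:
    # Without a GPU, the build instead follows the enable_sm90a / enable_fp8 /
    # enable_bf16 switches shown in the else branch of the diff.
    print("No CUDA device detected.")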
