From a53454c55e222fcc7375676b665b7e1276464170 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Thu, 16 Jan 2025 04:53:23 +0800
Subject: [PATCH] fix: sgl-kernel link cuda (#2906)

---
 sgl-kernel/build.sh                                |  2 +
 sgl-kernel/pyproject.toml                          |  2 +-
 sgl-kernel/setup.py                                |  2 +-
 .../csrc/sampling_scaling_penalties.cu             | 81 +++++++++----------
 .../src/sgl-kernel/csrc/vectorization.cuh          |  5 +-
 5 files changed, 44 insertions(+), 48 deletions(-)

diff --git a/sgl-kernel/build.sh b/sgl-kernel/build.sh
index 799b724dfe6..55ce9df7f33 100755
--- a/sgl-kernel/build.sh
+++ b/sgl-kernel/build.sh
@@ -11,6 +11,8 @@ docker run --rm \
    ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.4.0 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
    export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
    export CUDA_VERSION=${CUDA_VERSION} && \
+   mkdir -p /usr/lib/x86_64-linux-gnu/ && \
+   ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \
    cd /sgl-kernel && \
    ${PYTHON_ROOT_PATH}/bin/python setup.py bdist_wheel
 "

diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index 6a6a0d1fe4e..b0554bd8fed 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post13"
+version = "0.0.2.post14"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.8"

diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py
index 2d2d9258ade..da6b465d841 100644
--- a/sgl-kernel/setup.py
+++ b/sgl-kernel/setup.py
@@ -41,7 +41,7 @@ def update_wheel_platform_tag():
 ]
 cxx_flags = ["-O3"]
 libraries = ["c10", "torch", "torch_python", "cuda"]
-extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
+extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
 ext_modules = [
     CUDAExtension(
         name="sgl_kernel.ops._kernels",

diff --git a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
index 30264caa366..a61d4b86059 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
+++ b/sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
@@ -1,64 +1,59 @@
 #include
 #include
 #include
+
 #include
+
 #include "utils.hpp"
 #include "vectorization.cuh"
 
 template <typename scalar_t>
-__global__ void sampling_scaling_penalties_kernel(
-    const scalar_t* logits,
-    const scalar_t* scaling_penalties,
-    scalar_t* output,
-    const int32_t numel) {
-
-    const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
-    const int32_t stride = blockDim.x * gridDim.x;
+__global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const scalar_t* scaling_penalties,
+                                                  scalar_t* output, const int32_t numel) {
+  const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const int32_t stride = blockDim.x * gridDim.x;
 
-    auto const* vectorized_logits = reinterpret_cast<vec4_t<scalar_t> const*>(logits);
-    auto const* vectorized_penalties = reinterpret_cast<vec4_t<scalar_t> const*>(scaling_penalties);
-    auto* vectorized_output = reinterpret_cast<vec4_t<scalar_t>*>(output);
+  auto const* vectorized_logits = reinterpret_cast<vec4_t<scalar_t> const*>(logits);
+  auto const* vectorized_penalties = reinterpret_cast<vec4_t<scalar_t> const*>(scaling_penalties);
+  auto* vectorized_output = reinterpret_cast<vec4_t<scalar_t>*>(output);
 
-    const int32_t num_vec_elems = numel >> 2;
+  const int32_t num_vec_elems = numel >> 2;
 
 #pragma unroll 4
-    for (int32_t i = tid; i < num_vec_elems; i += stride) {
-        vec4_t<scalar_t> logits_vec = vectorized_logits[i];
-        vec4_t<scalar_t> penalties_vec = vectorized_penalties[i];
-        vec4_t<scalar_t> out_vec;
-
-        out_vec.x = logits_vec.x > 0 ? logits_vec.x / penalties_vec.x : logits_vec.x * penalties_vec.x;
-        out_vec.y = logits_vec.y > 0 ? logits_vec.y / penalties_vec.y : logits_vec.y * penalties_vec.y;
-        out_vec.z = logits_vec.z > 0 ? logits_vec.z / penalties_vec.z : logits_vec.z * penalties_vec.z;
-        out_vec.w = logits_vec.w > 0 ? logits_vec.w / penalties_vec.w : logits_vec.w * penalties_vec.w;
-
-        vectorized_output[i] = out_vec;
-    }
-
-    const int32_t start_idx = num_vec_elems * 4;
-    for (int32_t i = start_idx + tid; i < numel; i += stride) {
-        scalar_t logit = logits[i];
-        scalar_t penalty = scaling_penalties[i];
-        output[i] = logit > 0 ? logit / penalty : logit * penalty;
-    }
+  for (int32_t i = tid; i < num_vec_elems; i += stride) {
+    vec4_t<scalar_t> logits_vec = vectorized_logits[i];
+    vec4_t<scalar_t> penalties_vec = vectorized_penalties[i];
+    vec4_t<scalar_t> out_vec;
+
+    out_vec.x = logits_vec.x > 0 ? logits_vec.x / penalties_vec.x : logits_vec.x * penalties_vec.x;
+    out_vec.y = logits_vec.y > 0 ? logits_vec.y / penalties_vec.y : logits_vec.y * penalties_vec.y;
+    out_vec.z = logits_vec.z > 0 ? logits_vec.z / penalties_vec.z : logits_vec.z * penalties_vec.z;
+    out_vec.w = logits_vec.w > 0 ? logits_vec.w / penalties_vec.w : logits_vec.w * penalties_vec.w;
+
+    vectorized_output[i] = out_vec;
+  }
+
+  const int32_t start_idx = num_vec_elems * 4;
+  for (int32_t i = start_idx + tid; i < numel; i += stride) {
+    scalar_t logit = logits[i];
+    scalar_t penalty = scaling_penalties[i];
+    output[i] = logit > 0 ? logit / penalty : logit * penalty;
+  }
 }
 
 torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties) {
-    auto output = torch::empty_like(logits);
-    const auto numel = logits.numel();
-    const int threads = 512;
+  auto output = torch::empty_like(logits);
+  const auto numel = logits.numel();
+  const int threads = 512;
 
-    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
-        logits.scalar_type(), "sampling_scaling_penalties_kernel", ([&] {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16, logits.scalar_type(), "sampling_scaling_penalties_kernel", ([&] {
         const int blocks = (numel + threads * 4 - 1) / (threads * 4);
         sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
-            logits.data_ptr<scalar_t>(),
-            scaling_penalties.data_ptr<scalar_t>(),
-            output.data_ptr<scalar_t>(),
-            numel);
-    }));
+            logits.data_ptr<scalar_t>(), scaling_penalties.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), numel);
+      }));
 
-    return output;
+  return output;
 }

diff --git a/sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh b/sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
index cb36d0e7a45..2bfb710189b 100644
--- a/sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
+++ b/sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
@@ -6,8 +6,8 @@
 
 // Include both AMD and NVIDIA fp8 types to avoid circular import
 // TODO(luka/varun) use FP8_TYPE instead after refactoring
-#include
 #include
+#include
 
 // Vectorization containers
 template <typename scalar_t>
@@ -20,8 +20,7 @@ struct __align__(8) vec4_t {
 
 template <typename quant_type_t>
 struct __align__(4) q8x4_t {
-  static_assert(std::is_same_v ||
-                std::is_same_v ||
+  static_assert(std::is_same_v || std::is_same_v ||
                 std::is_same_v);
   quant_type_t x;
   quant_type_t y;
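
For reference, the element-wise rule applied by the reformatted kernel above can be written as a short PyTorch sketch. This is only an illustration of the semantics visible in the diff, not sgl-kernel's Python API; the function name, tensor shapes, and penalty value below are made up for the example.

    import torch

    def sampling_scaling_penalties_ref(logits: torch.Tensor, penalties: torch.Tensor) -> torch.Tensor:
        # Same rule as the CUDA kernel: divide positive logits by the penalty,
        # multiply non-positive logits by it, element-wise.
        return torch.where(logits > 0, logits / penalties, logits * penalties)

    # Example: batch of 2 sequences over a 6-token vocabulary, uniform penalty 1.2.
    logits = torch.randn(2, 6)
    penalties = torch.full_like(logits, 1.2)
    out = sampling_scaling_penalties_ref(logits, penalties)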