Skip to content

Commit

Permalink
fix: sgl-kernel link cuda (#2906)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhyncs authored Jan 15, 2025
1 parent 6cb3974 commit a53454c
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 48 deletions.
2 changes: 2 additions & 0 deletions sgl-kernel/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ docker run --rm \
${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.4.0 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
export CUDA_VERSION=${CUDA_VERSION} && \
mkdir -p /usr/lib/x86_64-linux-gnu/ && \
ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \
cd /sgl-kernel && \
${PYTHON_ROOT_PATH}/bin/python setup.py bdist_wheel
"
2 changes: 1 addition & 1 deletion sgl-kernel/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sgl-kernel"
version = "0.0.2.post13"
version = "0.0.2.post14"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.8"
Expand Down
2 changes: 1 addition & 1 deletion sgl-kernel/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def update_wheel_platform_tag():
]
cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python", "cuda"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
ext_modules = [
CUDAExtension(
name="sgl_kernel.ops._kernels",
Expand Down
81 changes: 38 additions & 43 deletions sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
Original file line number Diff line number Diff line change
@@ -1,64 +1,59 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <THC/THCAtomics.cuh>

#include "utils.hpp"
#include "vectorization.cuh"

template <typename scalar_t>
__global__ void sampling_scaling_penalties_kernel(
const scalar_t* logits,
const scalar_t* scaling_penalties,
scalar_t* output,
const int32_t numel) {

const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const int32_t stride = blockDim.x * gridDim.x;
__global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const scalar_t* scaling_penalties,
scalar_t* output, const int32_t numel) {
const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const int32_t stride = blockDim.x * gridDim.x;

auto const* vectorized_logits = reinterpret_cast<vec4_t<scalar_t> const*>(logits);
auto const* vectorized_penalties = reinterpret_cast<vec4_t<scalar_t> const*>(scaling_penalties);
auto* vectorized_output = reinterpret_cast<vec4_t<scalar_t>*>(output);
auto const* vectorized_logits = reinterpret_cast<vec4_t<scalar_t> const*>(logits);
auto const* vectorized_penalties = reinterpret_cast<vec4_t<scalar_t> const*>(scaling_penalties);
auto* vectorized_output = reinterpret_cast<vec4_t<scalar_t>*>(output);

const int32_t num_vec_elems = numel >> 2;
const int32_t num_vec_elems = numel >> 2;

#pragma unroll 4
for (int32_t i = tid; i < num_vec_elems; i += stride) {
vec4_t<scalar_t> logits_vec = vectorized_logits[i];
vec4_t<scalar_t> penalties_vec = vectorized_penalties[i];
vec4_t<scalar_t> out_vec;

out_vec.x = logits_vec.x > 0 ? logits_vec.x / penalties_vec.x : logits_vec.x * penalties_vec.x;
out_vec.y = logits_vec.y > 0 ? logits_vec.y / penalties_vec.y : logits_vec.y * penalties_vec.y;
out_vec.z = logits_vec.z > 0 ? logits_vec.z / penalties_vec.z : logits_vec.z * penalties_vec.z;
out_vec.w = logits_vec.w > 0 ? logits_vec.w / penalties_vec.w : logits_vec.w * penalties_vec.w;

vectorized_output[i] = out_vec;
}

const int32_t start_idx = num_vec_elems * 4;
for (int32_t i = start_idx + tid; i < numel; i += stride) {
scalar_t logit = logits[i];
scalar_t penalty = scaling_penalties[i];
output[i] = logit > 0 ? logit / penalty : logit * penalty;
}
for (int32_t i = tid; i < num_vec_elems; i += stride) {
vec4_t<scalar_t> logits_vec = vectorized_logits[i];
vec4_t<scalar_t> penalties_vec = vectorized_penalties[i];
vec4_t<scalar_t> out_vec;

out_vec.x = logits_vec.x > 0 ? logits_vec.x / penalties_vec.x : logits_vec.x * penalties_vec.x;
out_vec.y = logits_vec.y > 0 ? logits_vec.y / penalties_vec.y : logits_vec.y * penalties_vec.y;
out_vec.z = logits_vec.z > 0 ? logits_vec.z / penalties_vec.z : logits_vec.z * penalties_vec.z;
out_vec.w = logits_vec.w > 0 ? logits_vec.w / penalties_vec.w : logits_vec.w * penalties_vec.w;

vectorized_output[i] = out_vec;
}

const int32_t start_idx = num_vec_elems * 4;
for (int32_t i = start_idx + tid; i < numel; i += stride) {
scalar_t logit = logits[i];
scalar_t penalty = scaling_penalties[i];
output[i] = logit > 0 ? logit / penalty : logit * penalty;
}
}

torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties) {
auto output = torch::empty_like(logits);
const auto numel = logits.numel();
const int threads = 512;
auto output = torch::empty_like(logits);
const auto numel = logits.numel();
const int threads = 512;

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
logits.scalar_type(), "sampling_scaling_penalties_kernel", ([&] {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16, logits.scalar_type(), "sampling_scaling_penalties_kernel", ([&] {
const int blocks = (numel + threads * 4 - 1) / (threads * 4);
sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
logits.data_ptr<scalar_t>(),
scaling_penalties.data_ptr<scalar_t>(),
output.data_ptr<scalar_t>(),
numel);
}));
logits.data_ptr<scalar_t>(), scaling_penalties.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), numel);
}));

return output;
return output;
}
5 changes: 2 additions & 3 deletions sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

// Include both AMD and NVIDIA fp8 types to avoid circular import
// TODO(luka/varun) use FP8_TYPE instead after refactoring
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>

// Vectorization containers
template <typename scalar_t>
Expand All @@ -20,8 +20,7 @@ struct __align__(8) vec4_t {

template <typename quant_type_t>
struct __align__(4) q8x4_t {
static_assert(std::is_same_v<quant_type_t, int8_t> ||
std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
static_assert(std::is_same_v<quant_type_t, int8_t> || std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
quant_type_t x;
quant_type_t y;
Expand Down

0 comments on commit a53454c

Please sign in to comment.