add sampling_scaling_penalties kernel (#2846)
BBuf authored Jan 13, 2025
1 parent c4f9707 commit e2b16c4
Showing 9 changed files with 150 additions and 1 deletion.
1 change: 1 addition & 0 deletions sgl-kernel/CMakeLists.txt
@@ -32,6 +32,7 @@ add_library(_kernels SHARED
    src/sgl-kernel/csrc/trt_reduce_kernel.cu
    src/sgl-kernel/csrc/moe_align_kernel.cu
    src/sgl-kernel/csrc/int8_gemm_kernel.cu
    src/sgl-kernel/csrc/sampling_scaling_penalties.cu
    src/sgl-kernel/csrc/sgl_kernel_ops.cu
)

2 changes: 1 addition & 1 deletion sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sgl-kernel"
version = "0.0.2.post11"
version = "0.0.2.post12"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.8"
1 change: 1 addition & 0 deletions sgl-kernel/setup.py
@@ -50,6 +50,7 @@ def update_wheel_platform_tag():
            "src/sgl-kernel/csrc/trt_reduce_kernel.cu",
            "src/sgl-kernel/csrc/moe_align_kernel.cu",
            "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
            "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
            "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
        ],
        include_dirs=include_dirs,
2 changes: 2 additions & 0 deletions sgl-kernel/src/sgl-kernel/__init__.py
@@ -4,6 +4,7 @@
    init_custom_reduce,
    int8_scaled_mm,
    moe_align_block_size,
    sampling_scaling_penalties,
)

__all__ = [
@@ -12,4 +13,5 @@
    "custom_dispose",
    "custom_reduce",
    "int8_scaled_mm",
    "sampling_scaling_penalties",
]
64 changes: 64 additions & 0 deletions sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
@@ -0,0 +1,64 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>
#include "utils.hpp"
#include "vectorization.cuh"

template <typename scalar_t>
__global__ void sampling_scaling_penalties_kernel(
    const scalar_t* logits,
    const scalar_t* scaling_penalties,
    scalar_t* output,
    const int32_t numel) {

  const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t stride = blockDim.x * gridDim.x;

  auto const* vectorized_logits = reinterpret_cast<vec4_t<scalar_t> const*>(logits);
  auto const* vectorized_penalties = reinterpret_cast<vec4_t<scalar_t> const*>(scaling_penalties);
  auto* vectorized_output = reinterpret_cast<vec4_t<scalar_t>*>(output);

  const int32_t num_vec_elems = numel >> 2;

#pragma unroll 4
  for (int32_t i = tid; i < num_vec_elems; i += stride) {
    vec4_t<scalar_t> logits_vec = vectorized_logits[i];
    vec4_t<scalar_t> penalties_vec = vectorized_penalties[i];
    vec4_t<scalar_t> out_vec;

    out_vec.x = logits_vec.x > 0 ? logits_vec.x / penalties_vec.x : logits_vec.x * penalties_vec.x;
    out_vec.y = logits_vec.y > 0 ? logits_vec.y / penalties_vec.y : logits_vec.y * penalties_vec.y;
    out_vec.z = logits_vec.z > 0 ? logits_vec.z / penalties_vec.z : logits_vec.z * penalties_vec.z;
    out_vec.w = logits_vec.w > 0 ? logits_vec.w / penalties_vec.w : logits_vec.w * penalties_vec.w;

    vectorized_output[i] = out_vec;
  }

  const int32_t start_idx = num_vec_elems * 4;
  for (int32_t i = start_idx + tid; i < numel; i += stride) {
    scalar_t logit = logits[i];
    scalar_t penalty = scaling_penalties[i];
    output[i] = logit > 0 ? logit / penalty : logit * penalty;
  }
}

torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties) {
  auto output = torch::empty_like(logits);
  const auto numel = logits.numel();
  const int threads = 512;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
      logits.scalar_type(), "sampling_scaling_penalties_kernel", ([&] {
        const int blocks = (numel + threads * 4 - 1) / (threads * 4);
        sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
            logits.data_ptr<scalar_t>(),
            scaling_penalties.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            numel);
      }));

  return output;
}
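For context (not part of the commit): each thread consumes four elements at a time through the vec4_t loads, so the host function launches ceil(numel / (threads * 4)) blocks and the scalar tail loop covers the remaining numel % 4 elements. The elementwise rule is the usual repetition-penalty formulation, equivalent to this minimal PyTorch sketch (the helper name below is ours, not the library's):

import torch

def sampling_scaling_penalties_ref(logits: torch.Tensor, penalties: torch.Tensor) -> torch.Tensor:
    # Same rule as the CUDA kernel above: divide positive logits by the penalty,
    # multiply non-positive logits by it.
    return torch.where(logits > 0, logits / penalties, logits * penalties)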
5 changes: 5 additions & 0 deletions sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -12,6 +12,9 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
                          torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
                          torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer);

// sampling_scaling_penalties
torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties);

// int8_scaled_mm
torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a,
                             const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
@@ -24,6 +27,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)");
  // moe_align_block_size
  m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)");
  // sampling_scaling_penalties
  m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
  // int8_scaled_mm
  m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
}
30 changes: 30 additions & 0 deletions sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
@@ -0,0 +1,30 @@
// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/quantization/vectorization.cuh
#pragma once
/**
* __device__ datatypes vectorized by 4
*/

// Include both AMD and NVIDIA fp8 types to avoid circular import
// TODO(luka/varun) use FP8_TYPE instead after refactoring
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fn.h>

// Vectorization containers
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};

template <typename quant_type_t>
struct __align__(4) q8x4_t {
  static_assert(std::is_same_v<quant_type_t, int8_t> ||
                std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
                std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
  quant_type_t x;
  quant_type_t y;
  quant_type_t z;
  quant_type_t w;
};
7 changes: 7 additions & 0 deletions sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -3,6 +3,9 @@
from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
from sgl_kernel.ops._kernels import (
    sampling_scaling_penalties as _sampling_scaling_penalties,
)


def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out):
@@ -39,6 +42,10 @@ def moe_align_block_size(
    )


def sampling_scaling_penalties(logits, scaling_penalties):
    return _sampling_scaling_penalties(logits, scaling_penalties)


def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    return _int8_scaled_mm(
        mat_a,
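A quick usage sketch of the new wrapper (shapes and values are illustrative; the dedicated test below exercises more sizes and dtypes):

import torch
from sgl_kernel import sampling_scaling_penalties

# Illustrative inputs; penalties are kept positive, as in the test below.
logits = torch.randn(8, 32000, device="cuda", dtype=torch.float16)
penalties = torch.rand_like(logits) + 0.5
out = sampling_scaling_penalties(logits, penalties)  # same shape and dtype as logits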
39 changes: 39 additions & 0 deletions sgl-kernel/tests/test_sampling_scaling_penalties.py
@@ -0,0 +1,39 @@
import torch
from sgl_kernel import sampling_scaling_penalties


def test_sampling_scaling_penalties():
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65]
    vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767]
    dtypes = [torch.float32, torch.half, torch.bfloat16]
    device = torch.device("cuda")

    for dtype in dtypes:
        rtol = 1e-3
        atol = 1e-3

        for bs in batch_sizes:
            for vocab_size in vocab_sizes:
                logits = torch.randn(bs, vocab_size, device=device, dtype=dtype)
                scaling_penalties = (
                    torch.rand(bs, vocab_size, device=device, dtype=dtype) + 0.5
                )

                ref_output = torch.where(
                    logits > 0, logits / scaling_penalties, logits * scaling_penalties
                )

                kernel_output = sampling_scaling_penalties(logits, scaling_penalties)

                torch.testing.assert_close(
                    kernel_output,
                    ref_output,
                    rtol=rtol,
                    atol=atol,
                    msg=f"Failed for batch_size={bs}, vocab_size={vocab_size}, dtype={dtype}",
                )


if __name__ == "__main__":
    test_sampling_scaling_penalties()
    print("All tests passed!")
