[kernel] port rope cuda kernel to sgl-kernel (#2993)
Co-authored-by: Yineng Zhang <[email protected]>
ByronHsu and zhyncs authored Jan 20, 2025
1 parent 73401fd commit b5caa22
Showing 8 changed files with 255 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
@@ -222,3 +222,6 @@ work_dirs/
compile_commands.json

*.iml

# VSCode
.vscode
2 changes: 1 addition & 1 deletion sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "sgl-kernel"
version = "0.0.2.post14"
version = "0.0.2.post15"
description = "Kernel Library for SGLang"
readme = "README.md"
requires-python = ">=3.8"
1 change: 1 addition & 0 deletions sgl-kernel/setup.py
@@ -53,6 +53,7 @@ def update_wheel_platform_tag():
            "src/sgl-kernel/csrc/int8_gemm_kernel.cu",
            "src/sgl-kernel/csrc/sampling_scaling_penalties.cu",
            "src/sgl-kernel/csrc/sgl_kernel_ops.cu",
            "src/sgl-kernel/csrc/rotary_embedding.cu",
        ],
        include_dirs=include_dirs,
        extra_compile_args={
2 changes: 2 additions & 0 deletions sgl-kernel/src/sgl-kernel/__init__.py
@@ -6,6 +6,7 @@
    int8_scaled_mm,
    moe_align_block_size,
    register_graph_buffers,
    rotary_embedding,
    sampling_scaling_penalties,
)

@@ -18,4 +19,5 @@
    "sampling_scaling_penalties",
    "get_graph_buffer_ipc_meta",
    "register_graph_buffers",
    "rotary_embedding",
]
119 changes: 119 additions & 0 deletions sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu
@@ -0,0 +1,119 @@
// Reference: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
                                                    const scalar_t* __restrict__ sin_ptr, int rot_offset,
                                                    int embed_dim) {
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style rotary embedding.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = __ldg(cos_ptr + x_index);
    sin = __ldg(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = __ldg(cos_ptr + x_index / 2);
    sin = __ldg(sin_ptr + x_index / 2);
  }

  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                                                             // head_size] or [num_tokens, num_heads,
                                                                             // head_size]
                                              scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                                                             // head_size] or [num_tokens, num_kv_heads,
                                                                             // head_size]
                                              const scalar_t* cache_ptr, const int head_size, const int num_heads,
                                              const int num_kv_heads, const int rot_dim, const int token_idx,
                                              const int64_t query_stride, const int64_t key_stride) {
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;

  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }
}

template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                                                                // [num_tokens]
                                        scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                                                       // head_size] or [num_tokens, num_heads,
                                                                       // head_size]
                                        scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                                                       // head_size] or [num_tokens, num_kv_heads,
                                                                       // head_size]
                                        const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                                                     // 2]
                                        const int rot_dim, const int64_t query_stride, const int64_t key_stride,
                                        const int num_heads, const int num_kv_heads, const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
                                            token_idx, query_stride, key_stride);
}

void rotary_embedding(torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
                      torch::Tensor& query,      // [batch_size, seq_len, num_heads * head_size] or
                                                 // [num_tokens, num_heads * head_size]
                      torch::Tensor& key,        // [batch_size, seq_len, num_kv_heads * head_size] or
                                                 // [num_tokens, num_kv_heads * head_size]
                      int64_t head_size,
                      torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
                      bool is_neox) {
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int64_t query_stride = query.stride(-2);
  int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::BFloat16, at::ScalarType::Half, query.scalar_type(), "rotary_embedding", [&] {
        if (is_neox) {
          rotary_embedding_kernel<scalar_t, true>
              <<<grid, block, 0, stream>>>(positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
                                           key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
                                           query_stride, key_stride, num_heads, num_kv_heads, head_size);
        } else {
          rotary_embedding_kernel<scalar_t, false>
              <<<grid, block, 0, stream>>>(positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
                                           key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
                                           query_stride, key_stride, num_heads, num_kv_heads, head_size);
        }
      });
}
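Note: the kernel above rotates query and key in place, one thread block per token. For readers who want a high-level picture of the NeoX-style path (IS_NEOX == true), the following is a minimal PyTorch sketch of the same math; the function name rope_neox_reference and its shapes are illustrative only and are not part of the sgl-kernel API, and unlike the CUDA op it returns a new tensor rather than writing in place.

# Reference sketch (not part of sgl-kernel): NeoX-style rotary embedding per token and head.
import torch

def rope_neox_reference(positions, query, cos_sin_cache, head_size, rot_dim):
    # positions: [num_tokens]; query: [num_tokens, num_heads * head_size]
    # cos_sin_cache: [max_position, rot_dim], cos values in the first half, sin in the second.
    num_tokens = query.shape[0]
    q = query.reshape(num_tokens, -1, head_size)          # [num_tokens, num_heads, head_size]
    rot, rest = q[..., :rot_dim], q[..., rot_dim:]         # only the first rot_dim dims are rotated
    embed_dim = rot_dim // 2
    cache = cos_sin_cache[positions]                       # [num_tokens, rot_dim]
    cos = cache[:, :embed_dim].unsqueeze(1)                # broadcast over heads
    sin = cache[:, embed_dim:].unsqueeze(1)
    x, y = rot[..., :embed_dim], rot[..., embed_dim:]      # NeoX: pair element i with element i + embed_dim
    out = torch.cat((x * cos - y * sin, y * cos + x * sin), dim=-1)
    return torch.cat((out, rest), dim=-1).reshape_as(query)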
6 changes: 6 additions & 0 deletions sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu
@@ -26,6 +26,10 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma
                             const torch::Tensor& scales_b, const torch::Dtype& out_dtype,
                             const c10::optional<torch::Tensor>& bias);

// rotary embedding
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size,
                      torch::Tensor& cos_sin_cache, bool is_neox);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // trt_reduce
  m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)");
@@ -39,4 +43,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sampling_scaling_penalties", &sampling_scaling_penalties, "Sampling scaling penalties (CUDA)");
  // int8_scaled_mm
  m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)");
  // rotary embedding
  m.def("rotary_embedding", &rotary_embedding, "Rotary Embedding (CUDA)");
}
5 changes: 5 additions & 0 deletions sgl-kernel/src/sgl-kernel/ops/__init__.py
@@ -7,6 +7,7 @@
from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm
from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size
from sgl_kernel.ops._kernels import register_graph_buffers as _register_graph_buffers
from sgl_kernel.ops._kernels import rotary_embedding as _rotary_embedding
from sgl_kernel.ops._kernels import (
    sampling_scaling_penalties as _sampling_scaling_penalties,
)
@@ -71,3 +72,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
        out_dtype,
        bias,
    )


def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox):
    return _rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox)
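The wrapper is a thin pass-through to the compiled op. A minimal usage sketch follows; the tensor sizes and the cache construction are illustrative assumptions, not part of this commit, and the op updates query and key in place on the current CUDA device.

import torch
from sgl_kernel import rotary_embedding

num_tokens, num_heads, num_kv_heads, head_size = 8, 4, 2, 64
rot_dim, max_position = 64, 4096
positions = torch.arange(num_tokens, device="cuda", dtype=torch.int64)
query = torch.randn(num_tokens, num_heads * head_size, device="cuda", dtype=torch.float16)
key = torch.randn(num_tokens, num_kv_heads * head_size, device="cuda", dtype=torch.float16)

# cos_sin_cache is [max_position, rot_dim]: cos values in the first half, sin in the second.
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2, device="cuda").float() / rot_dim))
freqs = torch.outer(torch.arange(max_position, device="cuda").float(), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).to(query.dtype)

rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox=True)  # in-place update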
118 changes: 118 additions & 0 deletions sgl-kernel/tests/test_rotary_embedding.py
@@ -0,0 +1,118 @@
from typing import Optional, Tuple

import torch
from vllm.model_executor.layers.rotary_embedding import (
    RotaryEmbedding as VLLMRotaryEmbedding,
)


class SGLRotaryEmbedding(VLLMRotaryEmbedding):

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        from sgl_kernel import rotary_embedding

        self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)

        rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return query, key


# Compare the output of SGLRotaryEmbedding's forward_cuda with VLLMRotaryEmbedding's forward_native


def test_rotary_embedding():
    # Helper that runs one configuration and compares forward_cuda against forward_native
    def run_test(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        dtype,
        batch_size,
        seq_len,
        num_heads,
        test_name,
    ):
        print(f"\nRunning {test_name}...")
        # Initialize both implementations
        sgl_rope = SGLRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        ).to("cuda")
        vllm_rope = VLLMRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        ).to("cuda")

        # Regular forward pass
        positions = torch.arange(seq_len, device="cuda").repeat(batch_size)
        query = torch.randn(
            batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
        )
        key = torch.randn(
            batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
        )

        # Make copies for both implementations
        query_sgl = query.clone()
        key_sgl = key.clone()
        query_vllm = query.clone()
        key_vllm = key.clone()

        # Run both implementations
        query_sgl_out, key_sgl_out = sgl_rope.forward_cuda(
            positions, query_sgl, key_sgl
        )
        query_vllm_out, key_vllm_out = vllm_rope.forward_native(
            positions, query_vllm, key_vllm
        )

        # Compare outputs
        torch.testing.assert_close(query_sgl_out, query_vllm_out, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(key_sgl_out, key_vllm_out, rtol=1e-3, atol=1e-3)

        print(f"{test_name} passed!")

    # Test Case 1: FP32 with larger dimensions
    run_test(
        head_size=128,
        rotary_dim=64,
        max_position=4096,
        base=10000,
        is_neox_style=True,
        dtype=torch.float32,
        batch_size=4,
        seq_len=32,
        num_heads=8,
        test_name="FP32 Test",
    )

    # Test Case 2: BF16 with smaller dimensions
    run_test(
        head_size=64,
        rotary_dim=32,
        max_position=2048,
        base=8000,
        is_neox_style=True,
        dtype=torch.bfloat16,
        batch_size=2,
        seq_len=16,
        num_heads=4,
        test_name="BF16 Test",
    )


if __name__ == "__main__":
    test_rotary_embedding()
