From 8a96f749885cb52a0e38381d47fd80e0897e3bb3 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Mon, 27 Jan 2025 20:29:28 +0800
Subject: [PATCH] chore: bump 0.0.3 for sgl-kernel (#3178)

Co-authored-by: ispobock
Co-authored-by: BBuf <35585791+BBuf@users.noreply.github.com>
Co-authored-by: HandH1998 <007aabbcc411@gmail.com>
Co-authored-by: yizhang2077 <1109276519@qq.com>
Co-authored-by: ByronHsu
---
 sgl-kernel/pyproject.toml                    |   2 +-
 .../src/sgl-kernel/csrc/rotary_embedding.cu  | 119 ------------------
 sgl-kernel/version.py                        |   2 +-
 3 files changed, 2 insertions(+), 121 deletions(-)
 delete mode 100644 sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu

diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml
index 8664fb09021..aca6f045054 100644
--- a/sgl-kernel/pyproject.toml
+++ b/sgl-kernel/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.2.post20"
+version = "0.0.3"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.9"
diff --git a/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu b/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu
deleted file mode 100644
index d02554fb11c..00000000000
--- a/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu
+++ /dev/null
@@ -1,119 +0,0 @@
-// Reference: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
-
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-#include <torch/extension.h>
-
-template <typename scalar_t, bool IS_NEOX>
-inline __device__ void apply_token_rotary_embedding(scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
-                                                    const scalar_t* __restrict__ sin_ptr, int rot_offset,
-                                                    int embed_dim) {
-  int x_index, y_index;
-  scalar_t cos, sin;
-  if (IS_NEOX) {
-    // GPT-NeoX style rotary embedding.
-    x_index = rot_offset;
-    y_index = embed_dim + rot_offset;
-    cos = __ldg(cos_ptr + x_index);
-    sin = __ldg(sin_ptr + x_index);
-  } else {
-    // GPT-J style rotary embedding.
-    x_index = 2 * rot_offset;
-    y_index = 2 * rot_offset + 1;
-    cos = __ldg(cos_ptr + x_index / 2);
-    sin = __ldg(sin_ptr + x_index / 2);
-  }
-
-  const scalar_t x = arr[x_index];
-  const scalar_t y = arr[y_index];
-  arr[x_index] = x * cos - y * sin;
-  arr[y_index] = y * cos + x * sin;
-}
-
-template <typename scalar_t, bool IS_NEOX>
-inline __device__ void apply_rotary_embedding(scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
-                                                                             // head_size] or [num_tokens, num_heads,
-                                                                             // head_size]
-                                              scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
-                                                                             // head_size] or [num_tokens, num_kv_heads,
-                                                                             // head_size]
-                                              const scalar_t* cache_ptr, const int head_size, const int num_heads,
-                                              const int num_kv_heads, const int rot_dim, const int token_idx,
-                                              const int64_t query_stride, const int64_t key_stride) {
-  const int embed_dim = rot_dim / 2;
-  const scalar_t* cos_ptr = cache_ptr;
-  const scalar_t* sin_ptr = cache_ptr + embed_dim;
-
-  const int nq = num_heads * embed_dim;
-  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
-    const int head_idx = i / embed_dim;
-    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
-    const int rot_offset = i % embed_dim;
-    apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
-  }
-
-  const int nk = num_kv_heads * embed_dim;
-  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
-    const int head_idx = i / embed_dim;
-    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
-    const int rot_offset = i % embed_dim;
-    apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
-  }
-}
-
-template <typename scalar_t, bool IS_NEOX>
-__global__ void rotary_embedding_kernel(const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
-                                                                                // [num_tokens]
-                                        scalar_t* __restrict__ query,           // [batch_size, seq_len, num_heads,
-                                                                                // head_size] or [num_tokens, num_heads,
-                                                                                // head_size]
-                                        scalar_t* __restrict__ key,             // [batch_size, seq_len, num_kv_heads,
-                                                                                // head_size] or [num_tokens, num_kv_heads,
-                                                                                // head_size]
-                                        const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
-                                                                                     // 2]
-                                        const int rot_dim, const int64_t query_stride, const int64_t key_stride,
-                                        const int num_heads, const int num_kv_heads, const int head_size) {
-  // Each thread block is responsible for one token.
-  const int token_idx = blockIdx.x;
-  int64_t pos = positions[token_idx];
-  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
-
-  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
-                                            token_idx, query_stride, key_stride);
-}
-
-void rotary_embedding(torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
-                      torch::Tensor& query,      // [batch_size, seq_len, num_heads * head_size] or
-                                                 // [num_tokens, num_heads * head_size]
-                      torch::Tensor& key,        // [batch_size, seq_len, num_kv_heads * head_size] or
-                                                 // [num_tokens, num_kv_heads * head_size]
-                      int64_t head_size,
-                      torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
-                      bool is_neox) {
-  int64_t num_tokens = query.numel() / query.size(-1);
-  int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(-1) / head_size;
-  int num_kv_heads = key.size(-1) / head_size;
-  int64_t query_stride = query.stride(-2);
-  int64_t key_stride = key.stride(-2);
-
-  dim3 grid(num_tokens);  // each block is responsible for one token
-  dim3 block(std::min(num_heads * rot_dim / 2, 512));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-      at::ScalarType::BFloat16, at::ScalarType::Half, query.scalar_type(), "rotary_embedding", [&] {
-        if (is_neox) {
-          rotary_embedding_kernel<scalar_t, true>
-              <<<grid, block, 0, stream>>>(positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-                                           key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
-                                           query_stride, key_stride, num_heads, num_kv_heads, head_size);
-        } else {
-          rotary_embedding_kernel<scalar_t, false>
-              <<<grid, block, 0, stream>>>(positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
-                                           key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
-                                           query_stride, key_stride, num_heads, num_kv_heads, head_size);
-        }
-      });
-}
diff --git a/sgl-kernel/version.py b/sgl-kernel/version.py
index 45807e905cc..27fdca497c3 100644
--- a/sgl-kernel/version.py
+++ b/sgl-kernel/version.py
@@ -1 +1 @@
-__version__ = "0.0.2.post20"
+__version__ = "0.0.3"
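
Editor's note on the deleted kernel's cache layout: the removed rotary_embedding host function reads cos_sin_cache as a flat [max_position, rot_dim] tensor, with cosines in the first rot_dim / 2 columns of each row and sines in the second half (cos_ptr = cache_ptr; sin_ptr = cache_ptr + embed_dim). The sketch below is a minimal, illustrative way to build a cache with that layout; the helper name build_cos_sin_cache and the default base of 10000 are assumptions for illustration, not part of sgl-kernel.

import torch

def build_cos_sin_cache(rot_dim: int, max_position: int, base: float = 10000.0) -> torch.Tensor:
    # Illustrative sketch only: returns [max_position, rot_dim] with cos in the first
    # rot_dim // 2 columns and sin in the last rot_dim // 2, matching the layout the
    # deleted kernel expects when it offsets sin_ptr by embed_dim = rot_dim / 2.
    inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2, dtype=torch.float32) / rot_dim))
    t = torch.arange(max_position, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)                        # [max_position, rot_dim // 2]
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1)    # [max_position, rot_dim]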