[kernel] port rope cuda kernel to sgl-kernel (#2993)
Co-authored-by: Yineng Zhang <[email protected]>
Showing 8 changed files with 255 additions and 1 deletion.
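For orientation, here is a minimal sketch of how the new binding is called from Python. It mirrors the test added in this commit; the `cos_sin_cache` construction below is an illustrative stand-in for what vLLM's `RotaryEmbedding` precomputes (cos values in the first `rot_dim // 2` columns of each row, sin values in the second half), and it assumes `sgl_kernel` has been built with this change.

```python
import torch
from sgl_kernel import rotary_embedding  # binding added by this commit

# Illustrative sizes (not taken from the commit).
num_tokens, num_heads, num_kv_heads, head_size, rot_dim = 8, 4, 4, 64, 64
max_position = 2048

positions = torch.arange(num_tokens, device="cuda", dtype=torch.int64)
query = torch.randn(num_tokens, num_heads * head_size, device="cuda", dtype=torch.float16)
key = torch.randn(num_tokens, num_kv_heads * head_size, device="cuda", dtype=torch.float16)

# Assumed cache layout, matching the kernel: each row holds cos for the first
# rot_dim // 2 entries and sin for the remaining rot_dim // 2 entries.
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2, dtype=torch.float32) / rot_dim))
freqs = torch.outer(torch.arange(max_position, dtype=torch.float32), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).to("cuda", torch.float16)

# Applies RoPE to query and key in place; True selects the NeoX layout.
rotary_embedding(positions, query, key, head_size, cos_sin_cache, True)
```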
One small hunk extends an ignore list with editor artifacts:

```
@@ -222,3 +222,6 @@ work_dirs/
compile_commands.json

*.iml

# VSCode
.vscode
```
New file, 119 lines: the CUDA rotary embedding kernel.
```cuda
// Reference: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_token_rotary_embedding(scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr,
                                                    const scalar_t* __restrict__ sin_ptr, int rot_offset,
                                                    int embed_dim) {
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style rotary embedding.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = __ldg(cos_ptr + x_index);
    sin = __ldg(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = __ldg(cos_ptr + x_index / 2);
    sin = __ldg(sin_ptr + x_index / 2);
  }

  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}

template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                                                             // head_size] or [num_tokens, num_heads,
                                                                             // head_size]
                                              scalar_t* __restrict__ key,    // [batch_size, seq_len, num_kv_heads,
                                                                             // head_size] or [num_tokens, num_kv_heads,
                                                                             // head_size]
                                              const scalar_t* cache_ptr, const int head_size, const int num_heads,
                                              const int num_kv_heads, const int rot_dim, const int token_idx,
                                              const int64_t query_stride, const int64_t key_stride) {
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;

  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }

  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_token_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
  }
}

template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(const int64_t* __restrict__ positions,  // [batch_size, seq_len] or
                                                                                // [num_tokens]
                                        scalar_t* __restrict__ query,           // [batch_size, seq_len, num_heads,
                                                                                // head_size] or [num_tokens,
                                                                                // num_heads, head_size]
                                        scalar_t* __restrict__ key,             // [batch_size, seq_len, num_kv_heads,
                                                                                // head_size] or [num_tokens,
                                                                                // num_kv_heads, head_size]
                                        const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
                                        const int rot_dim, const int64_t query_stride, const int64_t key_stride,
                                        const int num_heads, const int num_kv_heads, const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  apply_rotary_embedding<scalar_t, IS_NEOX>(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim,
                                            token_idx, query_stride, key_stride);
}

void rotary_embedding(torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
                      torch::Tensor& query,      // [batch_size, seq_len, num_heads * head_size] or
                                                 // [num_tokens, num_heads * head_size]
                      torch::Tensor& key,        // [batch_size, seq_len, num_kv_heads * head_size] or
                                                 // [num_tokens, num_kv_heads * head_size]
                      int64_t head_size,
                      torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
                      bool is_neox) {
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int64_t query_stride = query.stride(-2);
  int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  dim3 block(std::min<int64_t>(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::BFloat16, at::ScalarType::Half, query.scalar_type(), "rotary_embedding", [&] {
        if (is_neox) {
          rotary_embedding_kernel<scalar_t, true>
              <<<grid, block, 0, stream>>>(positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
                                           key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
                                           query_stride, key_stride, num_heads, num_kv_heads, head_size);
        } else {
          rotary_embedding_kernel<scalar_t, false>
              <<<grid, block, 0, stream>>>(positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
                                           key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), rot_dim,
                                           query_stride, key_stride, num_heads, num_kv_heads, head_size);
        }
      });
}
```
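The `IS_NEOX` template flag only changes how elements within each head are paired with cos/sin entries. As a rough PyTorch sketch of the same per-token rotation (function and argument names here are illustrative, not part of the commit):

```python
import torch

def rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor,
                   is_neox: bool) -> torch.Tensor:
    """Rotate the first rot_dim elements of each head, mirroring the kernel.

    x:   [num_tokens, num_heads, head_size]
    cos: [num_tokens, rot_dim // 2]  (first half of the per-position cache row)
    sin: [num_tokens, rot_dim // 2]  (second half of that row)
    """
    rot_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    cos = cos[:, None, :]  # broadcast over heads
    sin = sin[:, None, :]
    if is_neox:
        # GPT-NeoX layout: element i is paired with element i + rot_dim // 2.
        a, b = x_rot.chunk(2, dim=-1)
        rotated = torch.cat((a * cos - b * sin, b * cos + a * sin), dim=-1)
    else:
        # GPT-J layout: adjacent elements (2i, 2i + 1) form a pair.
        a, b = x_rot[..., 0::2], x_rot[..., 1::2]
        rotated = torch.stack((a * cos - b * sin, b * cos + a * sin), dim=-1).flatten(-2)
    return torch.cat((rotated, x_pass), dim=-1)
```

In the NeoX layout the two rotated halves are contiguous, so the kernel reads cos/sin with the same index as the first element of the pair; in the GPT-J layout neighbouring elements form a pair, hence the `x_index / 2` lookup.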
New file, 118 lines: a Python test comparing the sgl_kernel rotary embedding against vLLM's native implementation.
```python
from typing import Optional, Tuple

import torch
from vllm.model_executor.layers.rotary_embedding import (
    RotaryEmbedding as VLLMRotaryEmbedding,
)


class SGLRotaryEmbedding(VLLMRotaryEmbedding):

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        from sgl_kernel import rotary_embedding

        self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)

        rotary_embedding(
            positions,
            query,
            key,
            self.head_size,
            self.cos_sin_cache,
            self.is_neox_style,
        )
        return query, key


# Compare the output of SGLRotaryEmbedding's forward_cuda with
# VLLMRotaryEmbedding's forward_native.


def test_rotary_embedding():
    def run_test(
        head_size,
        rotary_dim,
        max_position,
        base,
        is_neox_style,
        dtype,
        batch_size,
        seq_len,
        num_heads,
        test_name,
    ):
        print(f"\nRunning {test_name}...")
        # Initialize both implementations
        sgl_rope = SGLRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        ).to("cuda")
        vllm_rope = VLLMRotaryEmbedding(
            head_size, rotary_dim, max_position, base, is_neox_style, dtype
        ).to("cuda")

        # Regular forward pass
        positions = torch.arange(seq_len, device="cuda").repeat(batch_size)
        query = torch.randn(
            batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
        )
        key = torch.randn(
            batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype
        )

        # Make copies for both implementations
        query_sgl = query.clone()
        key_sgl = key.clone()
        query_vllm = query.clone()
        key_vllm = key.clone()

        # Run both implementations
        query_sgl_out, key_sgl_out = sgl_rope.forward_cuda(
            positions, query_sgl, key_sgl
        )
        query_vllm_out, key_vllm_out = vllm_rope.forward_native(
            positions, query_vllm, key_vllm
        )

        # Compare outputs
        torch.testing.assert_close(query_sgl_out, query_vllm_out, rtol=1e-3, atol=1e-3)
        torch.testing.assert_close(key_sgl_out, key_vllm_out, rtol=1e-3, atol=1e-3)

        print(f"{test_name} passed!")

    # Test Case 1: FP32 with larger dimensions
    run_test(
        head_size=128,
        rotary_dim=64,
        max_position=4096,
        base=10000,
        is_neox_style=True,
        dtype=torch.float32,
        batch_size=4,
        seq_len=32,
        num_heads=8,
        test_name="FP32 Test",
    )

    # Test Case 2: BF16 with smaller dimensions
    run_test(
        head_size=64,
        rotary_dim=32,
        max_position=2048,
        base=8000,
        is_neox_style=True,
        dtype=torch.bfloat16,
        batch_size=2,
        seq_len=16,
        num_heads=4,
        test_name="BF16 Test",
    )


if __name__ == "__main__":
    test_rotary_embedding()
```
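Two points worth noting when reading the test: the sgl_kernel binding mutates `query` and `key` in place (the wrapper simply returns the same tensors), and the comparison against `VLLMRotaryEmbedding.forward_native` uses rtol/atol of 1e-3, loose enough to cover the bf16 case. The file runs as a standalone script via its `__main__` guard, and the `test_` prefix should also let pytest pick it up, assuming vllm and a built sgl_kernel are installed.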