Showing 8 changed files with 147 additions and 26 deletions.
sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu (new file, 140 additions, 0 deletions)
// Adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/include/flashinfer/norm.cuh
// and https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/python/csrc/norm.cu
// TODO(zhyncs): tmp fix, v0.1.6 enables SGLang e2e to pass CIs unlike v0.2.0

#include <ATen/cuda/CUDAContext.h>

#include <flashinfer/math.cuh>
#include <flashinfer/utils.cuh>
#include <flashinfer/vec_dtypes.cuh>
#include <numeric>

#include "utils.h"

using namespace flashinfer;

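// Fused residual-add + RMSNorm, one thread block per row of (batch_size, d):
// pass 1 writes residual[i] += input[i] and accumulates the row's sum of
// squares; pass 2 rescales that sum into input[i] = residual[i] * rms_rcp * weight[i].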
template <uint32_t VEC_SIZE, typename T>
__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual, T* __restrict__ weight,
                                      const uint32_t d, float eps) {
  const uint32_t bx = blockIdx.x;
  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
  constexpr uint32_t warp_size = 32;
  const uint32_t num_warps = blockDim.y;
  const uint32_t thread_id = tx + ty * warp_size;
  const uint32_t num_threads = num_warps * warp_size;
  const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads);
  extern __shared__ float smem[];

  float sum_sq = 0.f;

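  // Pass 1: strided sweep over the row in VEC_SIZE chunks. Each in-range
  // thread adds input to residual, stores the sum back to residual, and
  // accumulates the sum of squares; zero-filled out-of-range vectors
  // contribute nothing to the reduction.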
  for (uint32_t i = 0; i < rounds; i++) {
    vec_t<T, VEC_SIZE> input_vec;
    input_vec.fill(0.f);
    vec_t<T, VEC_SIZE> residual_vec;
    residual_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      float x = float(input_vec[j]);
      x += float(residual_vec[j]);
      sum_sq += x * x;
      residual_vec[j] = (T)x;
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
  }

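  // Two-level block reduction: XOR-shuffle within each warp, stage one
  // partial sum per warp in shared memory, then reduce the partials in warp 0.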
  // first, warp reduce sum
#pragma unroll
  for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
    sum_sq += math::shfl_xor_sync(sum_sq, offset);
  }

  smem[ty] = sum_sq;
  __syncthreads();
  // then, cross warp reduce sum using only the first warp
  if (ty == 0) {
    sum_sq = (tx < num_warps) ? smem[tx] : 0.f;
#pragma unroll
    for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) {
      sum_sq += math::shfl_xor_sync(sum_sq, offset);
    }
    smem[0] = sum_sq;
  }
  __syncthreads();

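  // smem[0] now holds the block-wide sum of squares, so every thread can
  // derive the reciprocal RMS of the updated residual row.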
  float rms_rcp = math::rsqrt(smem[0] / float(d) + eps);

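  // Pass 2: reload the updated residual along with the weight, normalize,
  // and overwrite input in place with the RMSNorm output.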
  for (uint32_t i = 0; i < rounds; i++) {
    vec_t<T, VEC_SIZE> input_vec;
    vec_t<T, VEC_SIZE> weight_vec;
    vec_t<T, VEC_SIZE> residual_vec;
    input_vec.fill(0.f);
    weight_vec.fill(0.f);
    residual_vec.fill(0.f);
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
      residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
#pragma unroll
    for (uint32_t j = 0; j < VEC_SIZE; j++) {
      input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]);
    }
    if ((i * num_threads + thread_id) * VEC_SIZE < d) {
      input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE);
    }
  }
}

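// Host-side launcher: one block per row, up to 1024 threads arranged as
// (32, num_warps), the widest vector size (at most 16 bytes) that divides d,
// and one float of shared memory per warp for the cross-warp reduction.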
template <typename T>
cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d, float eps = 1e-5,
                            cudaStream_t stream = 0) {
  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);

  const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
  const uint32_t num_warps = ceil_div(block_size, 32);
  dim3 nblks(batch_size);
  dim3 nthrs(32, num_warps);
  const uint32_t smem_size = num_warps * sizeof(float);
  void* args[] = {&input, &residual, &weight, &d, &eps};

  DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, {
    auto kernel = FusedAddRMSNormKernel<VEC_SIZE, T>;
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });

  return cudaSuccess;
}

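// PyTorch entry point: validates contiguous CUDA tensors with matching
// shapes and devices, then dispatches on dtype and launches on the current stream.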
void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
  CHECK_INPUT(input);
  CHECK_INPUT(residual);
  CHECK_INPUT(weight);
  auto device = input.device();
  CHECK_EQ(residual.device(), device);
  CHECK_EQ(weight.device(), device);
  CHECK_DIM(2, input);     // input: (batch_size, hidden_size)
  CHECK_DIM(2, residual);  // residual: (batch_size, hidden_size)
  CHECK_DIM(1, weight);    // weight: (hidden_size)
  CHECK_EQ(input.size(0), residual.size(0));
  CHECK_EQ(input.size(1), residual.size(1));
  CHECK_EQ(input.size(1), weight.size(0));
  unsigned int batch_size = input.size(0);
  unsigned int hidden_size = input.size(1);

  cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
  // support float16, bfloat16 and float32
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] {
    cudaError_t status =
        FusedAddRMSNorm(static_cast<c_type*>(input.data_ptr()), static_cast<c_type*>(residual.data_ptr()),
                        static_cast<c_type*>(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream);
    TORCH_CHECK(status == cudaSuccess,
                "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));
    return true;
  });
}
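
To make the launch math concrete, here is a worked example for a hypothetical half-precision row width of d = 4096 (illustrative values, not taken from this diff):

// Worked example (T = half, d = 4096), following FusedAddRMSNorm above:
// vec_size   = gcd(16 / sizeof(half), 4096) = gcd(8, 4096) = 8 elements (16 B loads)
// block_size = min(1024, 4096 / 8)          = 512 threads
// num_warps  = ceil_div(512, 32)            = 16  -> dim3 nthrs(32, 16)
// smem_size  = 16 * sizeof(float)           = 64 bytes of shared memory
// rounds     = ceil_div(4096, 8 * 512)      = 1 sweep over the row per block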