add sampling_scaling_penalties kernel (#2846)
Showing 9 changed files with 150 additions and 1 deletion.
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu (new file, 64 additions)
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <THC/THCAtomics.cuh>
#include "utils.hpp"
#include "vectorization.cuh"

template <typename scalar_t>
__global__ void sampling_scaling_penalties_kernel(
    const scalar_t* logits,
    const scalar_t* scaling_penalties,
    scalar_t* output,
    const int32_t numel) {

  const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t stride = blockDim.x * gridDim.x;

  // View the buffers as packed groups of four elements so each loop
  // iteration loads and stores four values per tensor.
  auto const* vectorized_logits = reinterpret_cast<vec4_t<scalar_t> const*>(logits);
  auto const* vectorized_penalties = reinterpret_cast<vec4_t<scalar_t> const*>(scaling_penalties);
  auto* vectorized_output = reinterpret_cast<vec4_t<scalar_t>*>(output);

  const int32_t num_vec_elems = numel >> 2;

#pragma unroll 4
  for (int32_t i = tid; i < num_vec_elems; i += stride) {
    vec4_t<scalar_t> logits_vec = vectorized_logits[i];
    vec4_t<scalar_t> penalties_vec = vectorized_penalties[i];
    vec4_t<scalar_t> out_vec;

    // Penalty rule: positive logits are divided by the penalty,
    // non-positive logits are multiplied by it.
    out_vec.x = logits_vec.x > 0 ? logits_vec.x / penalties_vec.x : logits_vec.x * penalties_vec.x;
    out_vec.y = logits_vec.y > 0 ? logits_vec.y / penalties_vec.y : logits_vec.y * penalties_vec.y;
    out_vec.z = logits_vec.z > 0 ? logits_vec.z / penalties_vec.z : logits_vec.z * penalties_vec.z;
    out_vec.w = logits_vec.w > 0 ? logits_vec.w / penalties_vec.w : logits_vec.w * penalties_vec.w;

    vectorized_output[i] = out_vec;
  }

  // Scalar tail handles the remainder when numel is not a multiple of 4.
  const int32_t start_idx = num_vec_elems * 4;
  for (int32_t i = start_idx + tid; i < numel; i += stride) {
    scalar_t logit = logits[i];
    scalar_t penalty = scaling_penalties[i];
    output[i] = logit > 0 ? logit / penalty : logit * penalty;
  }
}

torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torch::Tensor& scaling_penalties) {
  auto output = torch::empty_like(logits);
  const auto numel = logits.numel();
  const int threads = 512;

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
    logits.scalar_type(), "sampling_scaling_penalties_kernel", ([&] {
      // Each thread covers roughly four elements via the vectorized path,
      // so the grid is sized for numel / (threads * 4).
      const int blocks = (numel + threads * 4 - 1) / (threads * 4);
      sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        logits.data_ptr<scalar_t>(),
        scaling_penalties.data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>(),
        numel);
    }));

  return output;
}
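Only three of the nine changed files are expanded on this page; since the test below imports the op from the sgl_kernel Python package, the remaining files presumably register the new function with the extension. A minimal, purely illustrative sketch of what such a binding could look like with the standard torch extension machinery (the file layout, module wiring, and docstring here are assumptions, not taken from this commit):

// Hypothetical binding sketch -- not part of the expanded diff above.
#include <torch/extension.h>

// Declaration of the host entry point defined in sampling_scaling_penalties.cu.
torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits,
                                         const torch::Tensor& scaling_penalties);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sampling_scaling_penalties", &sampling_scaling_penalties,
        "Apply sampling scaling penalties to logits (CUDA)");
}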
vectorization.cuh (new file, 30 additions)
// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/quantization/vectorization.cuh
#pragma once
/**
 * __device__ datatypes vectorized by 4
 */

// Include both AMD and NVIDIA fp8 types to avoid circular import
// TODO(luka/varun) use FP8_TYPE instead after refactoring
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fn.h>

// Vectorization containers
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};

template <typename quant_type_t>
struct __align__(4) q8x4_t {
  static_assert(std::is_same_v<quant_type_t, int8_t> ||
                std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
                std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
  quant_type_t x;
  quant_type_t y;
  quant_type_t z;
  quant_type_t w;
};
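The header only defines the packed containers (q8x4_t is carried over from the vLLM source and is not used by the kernel shown above). Any elementwise kernel can reuse vec4_t with the same vectorized-body-plus-scalar-tail pattern. A purely illustrative sketch, not part of this commit, with a made-up kernel name and op:

// Illustrative only: another elementwise kernel built on vec4_t.
template <typename scalar_t>
__global__ void relu_kernel(const scalar_t* input, scalar_t* output, const int32_t numel) {
  const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t stride = blockDim.x * gridDim.x;

  auto const* vectorized_input = reinterpret_cast<vec4_t<scalar_t> const*>(input);
  auto* vectorized_output = reinterpret_cast<vec4_t<scalar_t>*>(output);
  const int32_t num_vec_elems = numel >> 2;

  // Vectorized body: four elements per iteration.
  for (int32_t i = tid; i < num_vec_elems; i += stride) {
    vec4_t<scalar_t> in_vec = vectorized_input[i];
    vec4_t<scalar_t> out_vec;
    out_vec.x = in_vec.x > scalar_t(0) ? in_vec.x : scalar_t(0);
    out_vec.y = in_vec.y > scalar_t(0) ? in_vec.y : scalar_t(0);
    out_vec.z = in_vec.z > scalar_t(0) ? in_vec.z : scalar_t(0);
    out_vec.w = in_vec.w > scalar_t(0) ? in_vec.w : scalar_t(0);
    vectorized_output[i] = out_vec;
  }

  // Scalar tail covers the remainder when numel is not a multiple of 4.
  for (int32_t i = num_vec_elems * 4 + tid; i < numel; i += stride) {
    output[i] = input[i] > scalar_t(0) ? input[i] : scalar_t(0);
  }
}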
Python test for the new kernel (new file, 39 additions)
import torch
from sgl_kernel import sampling_scaling_penalties


def test_sampling_scaling_penalties():
    batch_sizes = [1, 2, 4, 8, 16, 32, 64, 65]
    vocab_sizes = [2048, 4096, 8192, 16384, 32768, 32767]
    dtypes = [torch.float32, torch.half, torch.bfloat16]
    device = torch.device("cuda")

    for dtype in dtypes:
        rtol = 1e-3
        atol = 1e-3

        for bs in batch_sizes:
            for vocab_size in vocab_sizes:
                logits = torch.randn(bs, vocab_size, device=device, dtype=dtype)
                scaling_penalties = (
                    torch.rand(bs, vocab_size, device=device, dtype=dtype) + 0.5
                )

                ref_output = torch.where(
                    logits > 0, logits / scaling_penalties, logits * scaling_penalties
                )

                kernel_output = sampling_scaling_penalties(logits, scaling_penalties)

                torch.testing.assert_close(
                    kernel_output,
                    ref_output,
                    rtol=rtol,
                    atol=atol,
                    msg=f"Failed for batch_size={bs}, vocab_size={vocab_size}, dtype={dtype}",
                )


if __name__ == "__main__":
    test_sampling_scaling_penalties()
    print("All tests passed!")
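Not part of this commit, but a natural companion to the correctness test: a rough timing sketch comparing the new kernel against the torch.where reference it replaces. The shapes, iteration counts, and the benchmark itself are assumptions for illustration; results will vary by GPU and dtype.

# Illustrative micro-benchmark sketch -- not part of this commit.
import torch
from sgl_kernel import sampling_scaling_penalties


def benchmark(bs=64, vocab_size=32768, dtype=torch.half, iters=100, warmup=10):
    device = torch.device("cuda")
    logits = torch.randn(bs, vocab_size, device=device, dtype=dtype)
    penalties = torch.rand(bs, vocab_size, device=device, dtype=dtype) + 0.5

    def run_ref():
        return torch.where(logits > 0, logits / penalties, logits * penalties)

    def run_kernel():
        return sampling_scaling_penalties(logits, penalties)

    for name, fn in (("torch.where", run_ref), ("sgl-kernel", run_kernel)):
        for _ in range(warmup):
            fn()
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            fn()
        end.record()
        torch.cuda.synchronize()
        print(f"{name}: {start.elapsed_time(end) / iters:.3f} ms/iter")


if __name__ == "__main__":
    benchmark()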