diff --git a/CMakeLists.txt b/CMakeLists.txt index aaab472eef143..b30b7c180d89a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -496,6 +496,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP") "csrc/rocm/attention.cu" "csrc/rocm/custom_kernels.cu" "csrc/rocm/fused_kernels.cu" + "csrc/rocm/fused_rope_and_reshape_cache.cu" "csrc/rocm/custom.cu") define_gpu_extension_target( diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index efda714f53c6c..a35f9b8e42ee0 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -1266,29 +1266,29 @@ void paged_attention( const std::string& kv_cache_dtype, double k_scale, double v_scale, const c10::optional& fp8_out_scale, int64_t partition_size) { const int head_size = query.size(2); - if (kv_cache_dtype == "auto") { - if (query.dtype() == at::ScalarType::Half) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, - vllm::Fp8KVCacheDataType::kAuto); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16, - vllm::Fp8KVCacheDataType::kAuto); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { - if (query.dtype() == at::ScalarType::Half) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, - vllm::Fp8KVCacheDataType::kFp8E4M3); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, - vllm::Fp8KVCacheDataType::kFp8E4M3); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } - } else { - TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype); - } + // if (kv_cache_dtype == "auto") { + // if (query.dtype() == at::ScalarType::Half) { + // CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, + // vllm::Fp8KVCacheDataType::kAuto); + // } else if (query.dtype() == at::ScalarType::BFloat16) { + // CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16, + // vllm::Fp8KVCacheDataType::kAuto); + // } else { + // TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + // } + // } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + // if (query.dtype() == at::ScalarType::Half) { + // CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, + // vllm::Fp8KVCacheDataType::kFp8E4M3); + // } else if (query.dtype() == at::ScalarType::BFloat16) { + // CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, + // vllm::Fp8KVCacheDataType::kFp8E4M3); + // } else { + // TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + // } + // } else { + // TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype); + // } } #undef WARP_SIZE diff --git a/csrc/rocm/fused_rope_and_reshape_cache.cu b/csrc/rocm/fused_rope_and_reshape_cache.cu new file mode 100644 index 0000000000000..fe25980e6f56a --- /dev/null +++ b/csrc/rocm/fused_rope_and_reshape_cache.cu @@ -0,0 +1,377 @@ +#include +#include +#include + +#include "dispatch_utils.h" +#include "attention/attention_dtypes.h" +#ifndef USE_ROCM + #include + #include + #include + #include +#else + #include + #include + #include + #include + #include "quantization/fp8/amd/hip_float8.h" + #include "quantization/fp8/amd/quant_utils.cuh" + +using __nv_bfloat16 = __hip_bfloat16; +using __nv_bfloat162 = __hip_bfloat162; +#endif + +#ifdef USE_ROCM + #include "quantization/fp8/amd/quant_utils.cuh" +#else + #include "quantization/fp8/nvidia/quant_utils.cuh" +#endif + +namespace { + +template +struct __align__(16) vec_t; + +template +__device__ void 
apply_rope(scalar_t* __restrict__ arr_ptr,
+           const scalar_t* __restrict__ sin_ptr,
+           const scalar_t* __restrict__ cos_ptr,
+           int rot_offset, int embed_dim,
+           const bool IS_NEOX,
+           vec_t<scalar_t, width>& __restrict__ out_xvec,
+           vec_t<scalar_t, width>& __restrict__ out_yvec);
+
+template <typename scalar_t, int width>
+struct __align__(16) vec_t {
+  scalar_t data[width];
+
+  __device__ vec_t() = default;
+  __device__ vec_t(const vec_t& __restrict__ other) {
+#pragma unroll
+    for (int i = 0; i < width; ++i) data[i] = other.data[i];
+  }
+
+  __device__ vec_t operator*(const vec_t& __restrict__ other) const {
+    vec_t tmp{*this};
+#pragma unroll
+    for (int i = 0; i < width; ++i) tmp.data[i] *= other.data[i];
+    return tmp;
+  }
+
+  __device__ vec_t operator*(const float& scale) const {
+    vec_t tmp{*this};
+#pragma unroll
+    for (int i = 0; i < width; ++i) tmp.data[i] *= scale;
+    return tmp;
+  }
+
+  __device__ vec_t operator+(const vec_t& __restrict__ other) const {
+    vec_t tmp{*this};
+#pragma unroll
+    for (int i = 0; i < width; ++i) tmp.data[i] += other.data[i];
+    return tmp;
+  }
+
+  __device__ vec_t operator-(const vec_t& __restrict__ other) const {
+    vec_t tmp{*this};
+#pragma unroll
+    for (int i = 0; i < width; ++i) tmp.data[i] -= other.data[i];
+    return tmp;
+  }
+
+  __device__ vec_t& operator=(const vec_t& __restrict__ other) {
+#pragma unroll
+    for (int i = 0; i < width; ++i) data[i] = other.data[i];
+    return *this;
+  }
+
+  __device__ vec_t& operator+=(const vec_t& __restrict__ other) {
+#pragma unroll
+    for (int i = 0; i < width; ++i) data[i] += other.data[i];
+    return *this;
+  }
+
+  __device__ scalar_t& operator [](const size_t& idx) {
+    return data[idx];
+  }
+
+  __device__ scalar_t operator [](const size_t& idx) const {
+    return data[idx];
+  }
+
+  friend
+  __device__ inline void apply_rope(
+      scalar_t* __restrict__ arr_ptr,
+      const scalar_t* __restrict__ sin_ptr,
+      const scalar_t* __restrict__ cos_ptr,
+      int rot_offset, int embed_dim,
+      const bool IS_NEOX,
+      vec_t<scalar_t, width>& __restrict__ out_xvec,
+      vec_t<scalar_t, width>& __restrict__ out_yvec);
+};
+
+
+template <typename scalar_t, int width>
+__device__ inline void apply_rope(scalar_t* __restrict__ arr_ptr,
+                                  const scalar_t* __restrict__ cos_ptr,
+                                  const scalar_t* __restrict__ sin_ptr,
+                                  int rot_offset, int embed_dim,
+                                  const bool IS_NEOX,
+                                  vec_t<scalar_t, width>& __restrict__ out_xvec,
+                                  vec_t<scalar_t, width>& __restrict__ out_yvec) {
+  const vec_t<scalar_t, width> xvec =
+      *reinterpret_cast<const vec_t<scalar_t, width>*>(arr_ptr + rot_offset);
+  const vec_t<scalar_t, width> yvec =
+      *reinterpret_cast<const vec_t<scalar_t, width>*>(arr_ptr + embed_dim + rot_offset);
+
+  if (IS_NEOX) {
+    const vec_t<scalar_t, width> cos =
+        *reinterpret_cast<const vec_t<scalar_t, width>*>(cos_ptr + rot_offset);
+    const vec_t<scalar_t, width> sin =
+        *reinterpret_cast<const vec_t<scalar_t, width>*>(sin_ptr + rot_offset);
+#pragma unroll
+    for (int i = 0; i < width; ++i) {
+      out_xvec[i] = xvec[i] * cos[i] - yvec[i] * sin[i];
+      out_yvec[i] = yvec[i] * cos[i] + xvec[i] * sin[i];
+    }
+  } else {
+    const vec_t<scalar_t, width / 2> xcos =
+        *reinterpret_cast<const vec_t<scalar_t, width / 2>*>(cos_ptr + rot_offset / 2);
+    const vec_t<scalar_t, width / 2> xsin =
+        *reinterpret_cast<const vec_t<scalar_t, width / 2>*>(sin_ptr + rot_offset / 2);
+#pragma unroll
+    for (int i = 0; i < width / 2; ++i) {
+      int x_i = 2 * i;
+      int y_i = 2 * i + 1;
+      out_xvec[x_i] = xvec[x_i] * xcos[i] - xvec[y_i] * xsin[i];
+      out_xvec[y_i] = xvec[y_i] * xcos[i] + xvec[x_i] * xsin[i];
+    }
+
+    const vec_t<scalar_t, width / 2> ycos =
+        *reinterpret_cast<const vec_t<scalar_t, width / 2>*>(cos_ptr + (embed_dim + rot_offset) / 2);
+    const vec_t<scalar_t, width / 2> ysin =
+        *reinterpret_cast<const vec_t<scalar_t, width / 2>*>(sin_ptr + (embed_dim + rot_offset) / 2);
+
+#pragma unroll
+    for (int i = 0; i < width / 2; ++i) {
+      int x_i = 2 * i;
+      int y_i = 2 * i + 1;
+      out_yvec[x_i] = yvec[x_i] * ycos[i] - yvec[y_i] * ysin[i];
+      out_yvec[y_i] = yvec[y_i] * ycos[i] + yvec[x_i] * ysin[i];
+    }
+  }
+}
+
+template <typename scalar_t, typename cache_t, vllm::Fp8KVCacheDataType kv_dt,
+          int width>
+__device__ void store_value_into_cache(
+    cache_t* __restrict__ cache,
+    const int head_idx, const int head_offset,
+    const int head_size, const int num_kv_heads,
+    const int64_t block_idx, const int block_size,
+    const int64_t block_offset, const int x,
+    vec_t<scalar_t, width>& val,
+    const float kv_scale) {
+  const int x_idx = head_offset / x;
+  const int x_offset = head_offset % x;
+
+  const int64_t tgt_idx =
+      block_idx * num_kv_heads * (head_size / x) * block_size * x +
+      head_idx * (head_size / x) * block_size * x +
+      x_idx * block_size * x +
+      block_offset * x + x_offset;
+
+  if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) {
+    *reinterpret_cast<vec_t<scalar_t, width>*>(cache + tgt_idx) = val;
+  } else {
+    *reinterpret_cast<vec_t<cache_t, width>*>(cache + tgt_idx) =
+        vllm::fp8::scaled_convert<vec_t<cache_t, width>, vec_t<scalar_t, width>,
+                                  kv_dt>(val, kv_scale);
+  }
+}
+
+}  // anonymous namespace
+
+namespace vllm {
+
+template <typename scalar_t, typename cache_t, vllm::Fp8KVCacheDataType kv_dt,
+          bool IS_NEOX>
+__global__ void __launch_bounds__ (512) fused_rotary_embedding_and_reshape_cache_kernel_vec(
+    scalar_t* __restrict__ query,        // [batch_size, seq_len, num_heads, head_size] or
+                                         // [num_tokens, num_heads, head_size]
+    scalar_t* __restrict__ key,          // [num_tokens, num_heads, head_size]
+    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
+    cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
+    cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size, block_size]
+    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
+    const int64_t* __restrict__ positions,       // [batch_size, seq_len] or [num_tokens]
+    const int64_t* __restrict__ slot_mapping,    // [num_tokens]
+    const int64_t query_stride, const int key_stride, const int value_stride,
+    const int num_heads, const int num_kv_heads, const int head_size,
+    const int rot_dim, const int block_size, const int x,
+    const float k_scale, const float v_scale) {
+
+  const int width = 16 / sizeof(scalar_t);
+  using vec_t = vec_t<scalar_t, width>;
+
+  // Each thread block is responsible for "width" tokens.
+  const int token_idx = blockIdx.x * blockDim.y + threadIdx.y;
+  int64_t pos = positions[token_idx];
+  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
+
+  const int embed_dim = rot_dim / 2;
+  const scalar_t* cos_ptr = cache_ptr;
+  const scalar_t* sin_ptr = cache_ptr + embed_dim;
+
+  const int nq = num_heads * embed_dim / width;
+  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
+    const int head_idx = (i * width) / embed_dim;
+    const int rot_offset = (i * width) % embed_dim;
+    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
+
+    vec_t& out_xvec = *reinterpret_cast<vec_t*>(query + token_head + rot_offset);
+    vec_t& out_yvec = *reinterpret_cast<vec_t*>(query + token_head + embed_dim + rot_offset);
+
+    apply_rope(query + token_head, cos_ptr, sin_ptr, rot_offset,
+               embed_dim, IS_NEOX, out_xvec, out_yvec);
+  }
+
+  const int64_t slot_idx = block_size == 0 ? 0 : slot_mapping[token_idx];
+  if (slot_idx < 0) {
+    // Padding token that should be ignored.
+    return;
+  }
+
+  const int64_t block_idx = slot_idx / block_size;
+  const int64_t block_offset = slot_idx % block_size;
+
+  const int nk = num_kv_heads * embed_dim / width;
+  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+    const int head_idx = (i * width) / embed_dim;
+    const int rot_offset = (i * width) % embed_dim;
+    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
+
+    // FIXME: check why we need to modify key.
+    vec_t& out_xvec = *reinterpret_cast<vec_t*>(key + token_head + rot_offset);
+    vec_t& out_yvec = *reinterpret_cast<vec_t*>(key + token_head + embed_dim + rot_offset);
+
+    apply_rope(key + token_head, cos_ptr, sin_ptr, rot_offset,
+               embed_dim, IS_NEOX, out_xvec, out_yvec);
+
+    if (block_size != 0) {
+      store_value_into_cache<scalar_t, cache_t, kv_dt>
+          (key_cache, head_idx, rot_offset,
+           head_size, num_kv_heads,
+           block_idx, block_size, block_offset,
+           x, out_xvec, k_scale);
+      store_value_into_cache<scalar_t, cache_t, kv_dt>
+          (key_cache, head_idx, embed_dim + rot_offset,
+           head_size, num_kv_heads,
+           block_idx, block_size, block_offset,
+           x, out_yvec, k_scale);
+    }
+  }
+
+  const int nv = num_kv_heads * head_size;
+  for (int i = threadIdx.x; block_size && i < nv; i += blockDim.x) {
+    const int head_idx = i / head_size;
+    const int head_offset = i % head_size;
+
+    const int64_t src_value_idx = token_idx * value_stride + i;
+    scalar_t tgt_value = value[src_value_idx];
+
+    const int64_t tgt_idx =
+        block_idx * num_kv_heads * head_size * block_size +
+        head_idx * head_size * block_size +
+        head_offset * block_size +
+        block_offset;
+
+    if constexpr (kv_dt == vllm::Fp8KVCacheDataType::kAuto) {
+      value_cache[tgt_idx] = tgt_value;
+    } else {
+      value_cache[tgt_idx] =
+          vllm::fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+    }
+  }
+}
+
+}  // namespace vllm
+
+
+// SRC_DTYPE is the stored data type of qkv.
+// CACHE_T is the data type of the KV cache.
+// KV_DT is the actual data type stored inside the KV cache.
+// IS_NEOX selects how the rotary positional encoding is applied.
+#define CALL_ROTARY_EMBEDDING_RESHAPE_AND_CACHE(QKV_T, CACHE_T, KV_DT, IS_NEOX)  \
+  vllm::fused_rotary_embedding_and_reshape_cache_kernel_vec<QKV_T, CACHE_T,      \
+                                                            KV_DT, IS_NEOX>      \
+      <<<grid, block, 0, stream>>>(                                              \
+          reinterpret_cast<QKV_T*>(query.data_ptr()),                            \
+          reinterpret_cast<QKV_T*>(key.data_ptr()),                              \
+          reinterpret_cast<QKV_T*>(value.data_ptr()),                            \
+          reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                      \
+          reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                    \
+          reinterpret_cast<QKV_T*>(cos_sin_cache.data_ptr()),                    \
+          positions.data_ptr<int64_t>(),                                         \
+          slot_mapping.data_ptr<int64_t>(),                                      \
+          query_stride, key_stride, value_stride,                                \
+          num_heads, num_kv_heads, head_size,                                    \
+          rot_dim, block_size, x, k_scale, v_scale);
+
+
+// The following macro is used to dispatch the conversion function based on
+// the data type of the key and value cache. FN is a macro that calls a
+// templated function.
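+// A possible extension (sketch only, not wired up in this patch): the dispatch
+// macro below currently handles only kv_cache_dtype == "auto", while the unit
+// test also parametrizes over "fp8". Mirroring the fp8 handling used by the
+// paged_attention dispatch above, an additional branch could look like:
+//
+//   } else if (CACHE_T == "fp8" || CACHE_T == "fp8_e4m3") {
+//     if (SRC_DTYPE == at::ScalarType::Half) {
+//       FN(c10::Half, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3, IS_NEOX);
+//     } else if (SRC_DTYPE == at::ScalarType::BFloat16) {
+//       FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3, IS_NEOX);
+//     } else {
+//       TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);
+//     }
+//   } else {
+//     TORCH_CHECK(false, "Unsupported KV cache dtype: ", CACHE_T);
+//   }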
+#define DISPATCH_ROPE_BY_KV_CACHE_DTYPE(SRC_DTYPE, CACHE_T, IS_NEOX, FN)      \
+  if (CACHE_T == "auto") {                                                    \
+    if (SRC_DTYPE == at::ScalarType::Half) {                                  \
+      FN(c10::Half, c10::Half, vllm::Fp8KVCacheDataType::kAuto, IS_NEOX);     \
+    } else if (SRC_DTYPE == at::ScalarType::BFloat16) {                       \
+      FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto,       \
+         IS_NEOX);                                                            \
+    } else {                                                                  \
+      TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE);  \
+    }                                                                         \
+  }
+
+
+void fused_rotary_embedding_and_reshape_cache(
+    torch::Tensor& positions,    // [batch_size, seq_len] or [num_tokens]
+    torch::Tensor& query,        // [batch_size, seq_len, num_heads * head_size] or
+                                 // [num_tokens, num_heads * head_size]
+    torch::Tensor& key,          // [batch_size, seq_len, num_kv_heads * head_size] or
+                                 // [num_tokens, num_kv_heads * head_size]
+    torch::Tensor& value,        // [batch_size, seq_len, num_heads * head_size] or
+                                 // [num_tokens, num_heads * head_size]
+    torch::Tensor& key_cache,    // [num_blocks, num_heads, head_size/x, block_size, x]
+    torch::Tensor& value_cache,  // [num_blocks, num_heads, head_size, block_size]
+    const std::string& kv_cache_dtype,
+    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
+    torch::Tensor& slot_mapping,   // [num_tokens]
+    const int64_t head_size,
+    const double k_scale,
+    const double v_scale,
+    bool is_neox) {
+
+  assert(query.scalar_type() == key.scalar_type());
+
+  int64_t num_tokens = query.numel() / query.size(-1);
+  int rot_dim = cos_sin_cache.size(1);
+  int num_heads = query.size(-1) / head_size;
+  int num_kv_heads = key.size(-1) / head_size;
+  int64_t query_stride = query.stride(-2);
+  int64_t key_stride = key.stride(-2);
+  int64_t value_stride = value.stride(-2);
+
+  int block_size = key_cache.size(3);
+  int x = key_cache.size(4);
+
+  auto width = 16 / key.element_size();
+  dim3 grid(num_tokens / width);
+  dim3 block(std::min(num_heads * rot_dim / 2, 512) / width, width);
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  if (is_neox) {
+    DISPATCH_ROPE_BY_KV_CACHE_DTYPE(key.scalar_type(), kv_cache_dtype, true,
+                                    CALL_ROTARY_EMBEDDING_RESHAPE_AND_CACHE)
+  } else {
+    DISPATCH_ROPE_BY_KV_CACHE_DTYPE(key.scalar_type(), kv_cache_dtype, false,
+                                    CALL_ROTARY_EMBEDDING_RESHAPE_AND_CACHE)
+  }
+}
\ No newline at end of file
diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h
index d825686a6ced4..75bd9d59e338d 100644
--- a/csrc/rocm/ops.h
+++ b/csrc/rocm/ops.h
@@ -20,3 +20,12 @@ void paged_attention(
     const c10::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, double k_scale, double v_scale,
     const c10::optional<torch::Tensor>& fp8_out_scale, int64_t partition_size);
+
+void fused_rotary_embedding_and_reshape_cache(
+    torch::Tensor& positions, torch::Tensor& query,
+    torch::Tensor& key, torch::Tensor& value,
+    torch::Tensor& key_cache, torch::Tensor& value_cache,
+    const std::string& kv_cache_dtype,
+    torch::Tensor& cos_sin_cache, torch::Tensor& slot_mapping,
+    const int64_t head_size, const double k_scale, const double v_scale,
+    bool is_neox);
diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp
index 6402a3b2b2b60..8650238f87863 100644
--- a/csrc/rocm/torch_bindings.cpp
+++ b/csrc/rocm/torch_bindings.cpp
@@ -43,6 +43,23 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
       "wvSpltK(Tensor in_a, Tensor in_b, Tensor! 
out_c, int N_in," " int CuCount) -> ()"); rocm_ops.impl("wvSpltK", torch::kCUDA, &wvSpltK); + rocm_ops.def( + "fused_rotary_embedding_and_reshape_cache(" + " Tensor positions," + " Tensor! query," + " Tensor! key," + " Tensor value," + " Tensor! key_cache," + " Tensor! value_cache," + " str kv_cache_dtype," + " Tensor cos_sin_cache," + " Tensor slot_mapping," + " int head_size," + " float k_scale," + " float v_scale," + " bool is_neox) -> ()"); + rocm_ops.impl("fused_rotary_embedding_and_reshape_cache", + torch::kCUDA, &fused_rotary_embedding_and_reshape_cache); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/tests/kernels/test_fused_rope_and_reshape_cache.py b/tests/kernels/test_fused_rope_and_reshape_cache.py new file mode 100644 index 0000000000000..203e48f829bc4 --- /dev/null +++ b/tests/kernels/test_fused_rope_and_reshape_cache.py @@ -0,0 +1,248 @@ +import random +from itertools import accumulate, product +from time import perf_counter +from typing import List, Optional + +import pytest +import torch + +from vllm import _custom_ops as ops +from vllm.config import CacheConfig +from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT, get_rope + +from .allclose_default import get_default_atol, get_default_rtol + +NUM_TOKENS = [42] # Arbitrary values for testing + +# Arbitrary values for testing +# don't make it too large. e.g. [1024, 36000] will OOM +NUM_BLOCKS = [1024] +# We assume fp8 is always enabled for testing. +KV_CACHE_DTYPE = ["auto", "fp8"] + +BLOCK_SIZES = [16] + +DTYPES = [torch.bfloat16, torch.float16] # FIXME torch.half +HEAD_SIZES = [128] +NUM_HEADS = [32] # Arbitrary values for testing +NUM_KV_HEADS = [8] # Arbitrary values for testing +ROTARY_DIMS = [None] # None means rotary dim == head size FIXME: 32 +BATCH_SIZES = [8] # Arbitrary values for testing +SEQ_LENS = [1024] # Arbitrary values for testing +IS_NEOX_STYLE = [True, False] +SEEDS = [0] +CUDA_DEVICES = [0] # FIXME + + +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("seq_len", SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_fused_rotary_embedding_with_no_cache( + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, + num_kv_heads: int, + rotary_dim: Optional[int], + dtype: torch.dtype, + is_neox_style: bool, + seed: int, + device: str, + max_position: int = 8192, + base: int = 10000, +) -> None: + + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + + if rotary_dim is None: + rotary_dim = head_size + + _ROPE_DICT.clear() + cache_config = CacheConfig(16, 1.0, 1, "auto") + rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_style, rope_scaling={"type": "llama3", + "low_freq_factor": 1.0, + "high_freq_factor": 2.0, + "original_max_position_embeddings": 1024}, + fused_with_kv_cache_op=True, + cache_config = cache_config) + rope = rope.to(dtype=dtype) + + #------------------Simulate------------------------------ + + positions = torch.randint(0, max_position, (batch_size, seq_len)) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=dtype) + 
key = torch.randn(batch_size, + seq_len, + num_kv_heads * head_size, + dtype=dtype) + value = torch.randn_like(key) + + ref_query, ref_key = rope.forward_native(positions, query, key) + + #----------------------Actual-Run------------------------ + + kv_scale = 1.0 + key_cache, value_cache = torch.empty(0, 0, 0, 0, 0), torch.empty(0, 0, 0, 0) + slot_mapping = torch.empty(0, dtype=torch.long) + + rope.forward( + positions, query, key, value, + key_cache, value_cache, + slot_mapping, kv_scale, kv_scale) + + #----------------------Assert---------------------------- + assert torch.allclose(query, ref_query, atol=0.001, rtol=0.1) + assert torch.allclose(key, ref_key, atol=0.001, rtol=0.1) + + + + +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("seq_len", SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_fused_rotary_embedding_with_reshape_cache( + kv_cache_factory, + block_size: int, + num_blocks: int, + kv_cache_dtype: str, + batch_size: int, + seq_len: int, + num_heads: int, + num_kv_heads: int, + head_size: int, + rotary_dim: Optional[int], + dtype: torch.dtype, + is_neox_style: bool, + seed: int, + device: str, + max_position: int = 8192, + base: int = 10000, +) -> None: + + torch.set_printoptions(precision=4) + torch.set_printoptions(threshold=5000) + + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + + if rotary_dim is None: + rotary_dim = head_size + + _ROPE_DICT.clear() + cache_config = CacheConfig(block_size, 1.0, 1, kv_cache_dtype) + rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_style, + rope_scaling={"rope_type": "llama3", + "factor": 1.0, + "low_freq_factor": 1.0, + "high_freq_factor": 2.0, + "original_max_position_embeddings": 1024}, + fused_with_kv_cache_op=True, + cache_config = cache_config) + rope = rope.to(dtype=dtype) + + # Create a random slot mapping. + num_slots = block_size * num_blocks + slot_mapping = random.sample(range(num_slots), batch_size * seq_len) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + device) + key_cache, value_cache = key_caches[0], value_caches[0] + + # print(f"{key_cache.shape=}") + # print(f"{value_cache.shape=}") + + # Clone the KV caches. 
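+    # The cloned caches are populated below by the reference path
+    # (rope.forward_cuda followed by ops.reshape_and_cache), so the caches
+    # written by the fused kernel can be compared against them at the end.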
+ cloned_key_cache = key_cache.clone() + cloned_value_cache = value_cache.clone() + + # Using default kv_scale + kv_scale = 1.0 + + positions = torch.randint(0, max_position, (batch_size, seq_len)) + # positions = positions.fill_(1) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=dtype) + key = torch.randn(batch_size, + seq_len, + num_kv_heads * head_size, + dtype=dtype) + value = torch.randn_like(key) + + #------------------Simulate------------------------------ + time_start = perf_counter() + + ref_query, ref_key = rope.forward_cuda(positions, query.clone(), key.clone()) + ops.reshape_and_cache(ref_key.view(-1, num_kv_heads, head_size), + value.view(-1, num_kv_heads, head_size), + cloned_key_cache, + cloned_value_cache, slot_mapping, + kv_cache_dtype, kv_scale, kv_scale) + + time_end = perf_counter() + # report the duration + print(f'Non fused call {(time_end - time_start) * 1000} ms') + + #----------------------Actual-Run------------------------ + time_start = perf_counter() + + rope.forward( + positions, query, key, value, + key_cache, value_cache, + slot_mapping, kv_scale, kv_scale) + + time_end = perf_counter() + # report the duration + print(f'Fused run {(time_end - time_start) * 1000} ms') + + #----------------------Assert---------------------------- + atol, rtol = 1e-05, 1e-05 + if kv_cache_dtype == "fp8": + atol, rtol = 0.001, 0.001 + + + # print(f"{query[0, 0, 0:128].view(-1, 8)}") + # print(f"{query[0, 0, 128:256].view(-1, 8)}") + # print(f"{query[0, 0, 256:384].view(-1, 8)}") + # print(f"{query[0, 0, 384:512].view(-1, 8)}") + # print(f"{query[0, 0, 512:640].view(-1, 8)}") + # print(f"{ref_query[0, 0, 0:128].view(-1, 8)}") + + assert torch.allclose(query, ref_query, atol=atol, rtol=rtol) + assert torch.allclose(key, ref_key, atol=atol, rtol=rtol) + assert torch.allclose(key_cache, cloned_key_cache, atol=atol, rtol=rtol) + assert torch.allclose(value_cache, cloned_value_cache, atol=atol, rtol=rtol) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 7d2d87176800c..fe0721a35d6f3 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -589,7 +589,8 @@ def forward( key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_kv_heads, self.head_size) - if key is not None and value is not None: + if key is not None and value is not None and \ + not envs.VLLM_FUSED_ROPE_W_KV_CACHE: # Reshape the input keys and values and store them in the # cache. If kv_cache is not provided, the new key and value # tensors are not cached. 
This happens during the initial diff --git a/vllm/envs.py b/vllm/envs.py index 65614422f126a..8d1f964b9e00d 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -82,6 +82,7 @@ VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1 VLLM_MOE_PADDING: bool = False VLLM_FP8_PADDING: bool = True + VLLM_FUSED_ROPE_W_KV_CACHE: bool = False def get_default_cache_root(): @@ -535,6 +536,10 @@ def get_default_config_root(): # Pad the weight for moe kernel or not "VLLM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))), + + # version of ROPE that places output to KV cache + "VLLM_FUSED_ROPE_W_KV_CACHE": + lambda: bool(int(os.getenv("VLLM_FUSED_ROPE_W_KV_CACHE", "0"))), } # end-env-vars-definition diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 2158ad3339673..7fbcd12e82471 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -27,9 +27,12 @@ import torch import torch.nn as nn +from vllm import _custom_ops as ops +from vllm.config import CacheConfig from vllm.model_executor.custom_op import CustomOp + def _rotate_neox(x: torch.Tensor) -> torch.Tensor: x1 = x[..., :x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2:] @@ -713,6 +716,48 @@ def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: return new_freqs + +class FusedLlama3RotaryEmbedding(Llama3RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + scaling_factor: float, + low_freq_factor: float, + high_freq_factor: float, + orig_max_position: int, + cache_config: Optional[CacheConfig] = None, + ) -> None: + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype, scaling_factor, low_freq_factor, high_freq_factor, + orig_max_position) + self.kv_cache_dtype = "auto" if cache_config is None else cache_config.cache_dtype + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + key_scale: float, + value_scale: float, + ) -> None: + # print(f"{self.cos_sin_cache[1, 0:64].view(-1, 8)=}") + # print(f"{self.cos_sin_cache[1, 64:128].view(-1, 8)=}") + # print(f"{query[0, 0, :].view(-1, 128)[0, :].view(-1, 8)=}") + torch.ops._rocm_C.fused_rotary_embedding_and_reshape_cache( + positions, query, key, value, key_cache, value_cache, + self.kv_cache_dtype, self.cos_sin_cache, slot_mapping, + self.head_size, key_scale, value_scale, self.is_neox_style) + class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" @@ -898,6 +943,8 @@ def get_rope( rope_scaling: Optional[Dict[str, Any]] = None, dtype: Optional[torch.dtype] = None, partial_rotary_factor: float = 1.0, + fused_with_kv_cache_op: Optional[bool] = False, + cache_config: Optional[CacheConfig] = None, ) -> RotaryEmbedding: if dtype is None: dtype = torch.get_default_dtype() @@ -929,12 +976,17 @@ def get_rope( high_freq_factor = rope_scaling["high_freq_factor"] original_max_position = rope_scaling[ "original_max_position_embeddings"] - rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, - max_position, base, - is_neox_style, dtype, - scaling_factor, low_freq_factor, - high_freq_factor, - original_max_position) + if fused_with_kv_cache_op: + rotary_emb = FusedLlama3RotaryEmbedding(head_size, rotary_dim, + max_position, base, 
is_neox_style, dtype, + scaling_factor, low_freq_factor, + high_freq_factor, original_max_position, + cache_config = cache_config) + else: + rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, + max_position, base, is_neox_style, dtype, + scaling_factor, low_freq_factor, + high_freq_factor, original_max_position) elif scaling_type == "default": if "mrope_section" in rope_scaling: rotary_emb = MRotaryEmbedding( diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 5125b0b1530a1..63e854bd12c73 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -27,9 +27,9 @@ from torch import nn from transformers import LlamaConfig -import vllm.envs as envs -from vllm import _custom_ops as ops +from vllm import _custom_ops as ops, envs from vllm.attention import Attention, AttentionMetadata +from vllm.attention.ops.paged_attn import PagedAttention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, LoRAConfig, PoolerConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, @@ -184,6 +184,8 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, is_neox_style=is_neox_style, + fused_with_kv_cache_op=envs.VLLM_FUSED_ROPE_W_KV_CACHE, + cache_config = cache_config, ) self.attn = Attention( self.num_heads, @@ -208,7 +210,20 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q, k = self.rotary_emb(positions, q, k) + + if envs.VLLM_FUSED_ROPE_W_KV_CACHE: + key_cache, value_cache = torch.empty(0, 0, 0, 0, 0), torch.empty(0, 0, 0, 0) + if kv_cache is not None and kv_cache.numel() > 0: + key_cache, value_cache =PagedAttention.split_kv_cache( + kv_cache, self.num_kv_heads, self.head_dim) + + self.rotary_emb(positions, q, k, v, + key_cache, value_cache, + attn_metadata.slot_mapping, + self.attn._k_scale, self.attn._v_scale) + else: + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v,
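+            # When VLLM_FUSED_ROPE_W_KV_CACHE=1, the fused rotary-embedding op
+            # above has already written the rotated key and the value into the
+            # paged KV cache, so the ROCm attention backend skips its own
+            # reshape_and_cache step (see the rocm_flash_attn.py hunk earlier
+            # in this diff).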