From 9cccf85d6bff218deede7d1f7fb6724c69d86725 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Mon, 2 Dec 2024 18:02:18 +0000 Subject: [PATCH 1/7] wip Signed-off-by: NickLucche --- .../kernels/benchmark_paged_attention.py | 2 + csrc/attention/attention_kernels.cuh | 50 +++++++++---- csrc/attention/paged_attention_v1.cu | 47 +++++++----- csrc/attention/paged_attention_v2.cu | 19 +++-- csrc/cpu/attention.cpp | 10 ++- csrc/cpu/torch_bindings.cpp | 5 +- csrc/ops.h | 6 +- csrc/torch_bindings.cpp | 4 +- tests/kernels/test_attention.py | 75 ++++++++++++------- vllm/_custom_ops.py | 8 +- vllm/attention/backends/blocksparse_attn.py | 1 + vllm/attention/backends/rocm_flash_attn.py | 1 + vllm/attention/backends/xformers.py | 2 + vllm/attention/ops/paged_attn.py | 5 ++ 14 files changed, 157 insertions(+), 78 deletions(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 14eef00b855ac..120b8ffe9c657 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -114,6 +114,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_seq_len, alibi_slopes, + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale, @@ -134,6 +135,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_seq_len, alibi_slopes, + None, kv_cache_dtype, k_scale, v_scale, diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 563e1438f0b01..25de77c324c62 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -104,6 +104,7 @@ __device__ void paged_attention_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ attn_bias, // [num_seqs, num_heads, max_seq_len] const int q_stride, const int kv_block_stride, const int kv_head_stride, const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, @@ -153,6 +154,21 @@ __device__ void paged_attention_kernel( const int kv_head_idx = head_idx / num_queries_per_kv; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + // TODO check if indexing still makes sense + // seq_len indexes on 'max_seq_lens' dim, + // it's like renaming dim you get attn_bias: seq_len x num_kv_heads x seq_len + // TODO each seq can have different len (seq_lens) but only one bias!! + // NOTE (NickLucche) `max_seq_len` bias values for current sequence and current head + const float* attn_bias_vec = + attn_bias == nullptr + ? nullptr + : attn_bias + seq_idx * num_heads * num_seq_blocks * BLOCK_SIZE + + head_idx * num_seq_blocks * BLOCK_SIZE; + // : attn_bias + seq_idx * num_kv_heads * num_seq_blocks * BLOCK_SIZE + + // const float* attn_bias_vec = attn_bias == nullptr + // ? nullptr + // : attn_bias + seq_idx * num_kv_heads * seq_len + + // kv_head_idx * seq_len; // A vector type to store a part of a key or a query. // The vector size is configured in such a way that the threads in a thread @@ -293,8 +309,12 @@ __device__ void paged_attention_kernel( // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot( q_vecs[thread_group_offset], k_vecs); - // Add the ALiBi bias if slopes are given. + // NOTE here each thread adds its own alibi (one per head..) 
like I am + // sure not the whole group needs to do so Add the ALiBi bias if slopes + // are given. + // TODO mutually exclusive? qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; + qk += (attn_bias_vec != nullptr) ? attn_bias_vec[token_idx] : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. @@ -512,17 +532,18 @@ __global__ void paged_attention_v1_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const float* __restrict__ attn_bias, const int q_stride, + const int kv_block_stride, const int kv_head_stride, const float k_scale, + const float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, - kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, - blocksparse_vert_stride, blocksparse_block_size, + max_num_blocks_per_seq, alibi_slopes, attn_bias, q_stride, + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } @@ -548,15 +569,16 @@ __global__ void paged_attention_v2_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const float* __restrict__ attn_bias, const int q_stride, + const int kv_block_stride, const int kv_head_stride, const float k_scale, + const float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, attn_bias, + q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 27321148f6dda..13a10221db425 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -29,20 +29,20 @@ #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), \ - shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ - out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ - scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ - blocksparse_vert_stride, blocksparse_block_size, \ +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, attn_bias_ptr, q_stride, kv_block_stride, \ + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step); // TODO(woosuk): Tune NUM_THREADS. @@ -53,8 +53,9 @@ void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const std::optional& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, float k_scale, float v_scale, + const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); @@ -73,7 +74,12 @@ void paged_attention_v1_launcher( alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - + const float* attn_bias_ptr = + attn_bias ? 
reinterpret_cast(attn_bias.value().data_ptr()) + : nullptr; + if (attn_bias_ptr){ + TORCH_CHECK(attn_bias.value().dtype() == torch::kFloat32, "Unsupported bias dtype: ", attn_bias.value().dtype()); + } T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); @@ -135,8 +141,8 @@ void paged_attention_v1_launcher( paged_attention_v1_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ + seq_lens, max_seq_len, alibi_slopes, attn_bias, k_scale, v_scale, \ + tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ @@ -176,7 +182,8 @@ void paged_attention_v1( torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, - const std::optional& alibi_slopes, + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index a453b2243e48c..80e1d7cb962df 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -36,9 +36,9 @@ <<>>( \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, \ + attn_bias_ptr, q_stride, kv_block_stride, kv_head_stride, k_scale, \ + v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel \ @@ -54,8 +54,9 @@ void paged_attention_v2_launcher( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const std::optional& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, float k_scale, float v_scale, + const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); @@ -74,6 +75,9 @@ void paged_attention_v2_launcher( alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; + const float* attn_bias_ptr = + attn_bias ? 
reinterpret_cast(attn_bias.value().data_ptr()) + : nullptr; T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -142,7 +146,7 @@ void paged_attention_v2_launcher( IS_BLOCK_SPARSE>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ - k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + attn_bias, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step); @@ -187,7 +191,8 @@ void paged_attention_v2( torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, - const std::optional& alibi_slopes, + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index ef5b14088c63b..eb33c66953a6e 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -459,7 +459,8 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, @@ -467,6 +468,8 @@ void paged_attention_v1( TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); + TORCH_CHECK(!attn_bias.has_value(), + "CPU backend does not support custom attention bias."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) @@ -781,7 +784,8 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, @@ -789,6 +793,8 @@ void paged_attention_v2( TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); + TORCH_CHECK(!attn_bias.has_value(), + "CPU backend does not support custom attention bias."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 74e4d8189d403..3cfa289848e21 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -24,12 +24,13 @@ 
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Attention ops // Compute the attention between an input query and the cached keys/values // using PagedAttention. + // TODO attn_bias on cpu ops.def( "paged_attention_v1(" " Tensor! out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," @@ -43,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," diff --git a/csrc/ops.h b/csrc/ops.h index 5a194a0dd3654..f2ad92074f446 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -33,7 +33,8 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, @@ -44,7 +45,8 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index fb53d122487d3..5ab404523494f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -29,7 +29,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," @@ -43,7 +43,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? 
attn_bias," " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 124d5d297a574..8a3b21e197612 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -18,7 +18,8 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer -MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +MAX_SEQ_LEN = 16 # There may not be enough gpu memory due to large NUM_BLOCKS. # Reduce NUM_BLOCKS when it happens. NUM_BLOCKS = 4321 # Arbitrary values for testing @@ -29,6 +30,7 @@ ] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing +# TODO fix different num of heads NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing # This should be sync with get_supported_head_sizes() in @@ -37,6 +39,7 @@ BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] +USE_CUSTOM_ATTN_BIAS = [False, True] KV_CACHE_DTYPE = ["auto", "fp8"] SEEDS = [0] CUDA_DEVICES = [ @@ -60,16 +63,11 @@ def ref_masked_attention( def ref_single_query_cached_kv_attention( - output: torch.Tensor, - query: torch.Tensor, - num_queries_per_kv: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - scale: float, - alibi_slopes: Optional[torch.Tensor], -) -> None: + output: torch.Tensor, query: torch.Tensor, num_queries_per_kv: int, + key_cache: torch.Tensor, value_cache: torch.Tensor, + block_tables: torch.Tensor, seq_lens: torch.Tensor, scale: float, + alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[List[torch.Tensor]]) -> None: num_query_heads = query.shape[1] num_kv_heads = value_cache.shape[1] head_size = value_cache.shape[2] @@ -102,15 +100,19 @@ def ref_single_query_cached_kv_attention( keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - alibi_bias = None + bias = None if alibi_slopes is not None: # Create the ALiBi bias used in the paged attention kernel. 
position_ids = torch.arange(seq_len).int() alibi_bias = (position_ids - seq_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( 1, 1, -1) - - out = ref_masked_attention(q, keys, values, scale, alibi_bias) + bias = alibi_bias + if attn_bias is not None: + # TODO test alibi + bias + bias = attn_bias[i] if bias is None else bias + attn_bias[i] + # print(f"ATTN BIAS {i}: {attn_bias[i]}") + out = ref_masked_attention(q, keys, values, scale, bias) out = out.view(num_query_heads, head_size) output[i].copy_(out, non_blocking=True) @@ -122,6 +124,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("use_custom_attn_bias", USE_CUSTOM_ATTN_BIAS) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @@ -134,12 +137,17 @@ def test_paged_attention( num_heads: Tuple[int, int], head_size: int, use_alibi: bool, + use_custom_attn_bias: bool, block_size: int, dtype: torch.dtype, kv_cache_dtype: str, seed: int, device: str, ) -> None: + # num_heads = (2, 2) + # num_seqs = 2 + # head_size = 32 + if ((kv_cache_dtype == "fp8" and head_size % 16) or (version == "rocm" and head_size not in (64, 128))): pytest.skip() @@ -153,7 +161,7 @@ def test_paged_attention( assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads - alibi_slopes = None + alibi_slopes, attn_bias = None, None if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) @@ -161,6 +169,20 @@ def test_paged_attention( seq_lens[-1] = MAX_SEQ_LEN max_seq_len = max(seq_lens) seq_lens = torch.tensor(seq_lens, dtype=torch.int) + attn_bias_list = None + if use_custom_attn_bias: + # NOTE (NickLucche) each sequence can have a different bias, + # depending on its len, but it *must* be float (f32)! + attn_bias_list = [torch.randn(num_query_heads, + 1, + seq_len, + dtype=torch.float) for seq_len in seq_lens] + attn_bias = torch.empty(num_seqs, num_query_heads, 1, max_seq_len, device=device, dtype=torch.float) + + for i, (seq_len, bias) in enumerate(zip(seq_lens, attn_bias_list)): + # first seq_len entries of the bias for each head/seq + attn_bias[i, :, :, :seq_len] = bias + # print("bias shape", attn_bias.shape) # Create the block tables. max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size @@ -186,6 +208,7 @@ def test_paged_attention( # Call the paged attention kernel. 
output = torch.empty_like(query) + # print("BIAS", attn_bias) if version == "v1": ops.paged_attention_v1( output, @@ -199,19 +222,23 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, ) + # print("\nOUT", output) opcheck(torch.ops._C.paged_attention_v1, (output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) elif version in ("v2", "rocm"): + assert False num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -240,6 +267,7 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -249,6 +277,7 @@ def test_paged_attention( (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) @@ -305,17 +334,11 @@ def test_paged_attention( value_cache = dequantized_value_cache ref_output = torch.empty_like(query) - ref_single_query_cached_kv_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - seq_lens, - scale, - alibi_slopes, - ) + ref_single_query_cached_kv_attention(ref_output, query, num_queries_per_kv, + key_cache, value_cache, block_tables, + seq_lens, scale, alibi_slopes, + attn_bias_list) + # print("\nREF OUT", ref_output) # NOTE(woosuk): Due to the kernel-level differences in the two # implementations, there is a small numerical difference in the two diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d04cbbc0a9eed..cae9c5e31fa68 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -47,6 +47,7 @@ def paged_attention_v1( block_size: int, max_seq_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], kv_cache_dtype: str, k_scale: float, v_scale: float, @@ -58,8 +59,8 @@ def paged_attention_v1( ) -> None: torch.ops._C.paged_attention_v1( out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - k_scale, v_scale, tp_rank, blocksparse_local_blocks, + seq_lens, block_size, max_seq_len, alibi_slopes, attn_bias, + kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) @@ -79,6 +80,7 @@ def paged_attention_v2( block_size: int, max_seq_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], kv_cache_dtype: str, k_scale: float, v_scale: float, @@ -91,7 +93,7 @@ def paged_attention_v2( torch.ops._C.paged_attention_v2( out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, + alibi_slopes, attn_bias, kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 9089db1126c94..0700f4195a634 100644 --- 
a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -439,6 +439,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + None, # TODO support attn_bias layer._k_scale, layer._v_scale, tp_rank=self.tp_rank, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e9f2808ff1674..477c52141020b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -628,6 +628,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + None, # TODO support attn_bias layer._k_scale, layer._v_scale, ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 38e27434dab2c..733287ab51e2b 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -577,6 +577,7 @@ def forward( prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, + prefill_meta.attn_bias, self.sliding_window, layer._k_scale, layer._v_scale, @@ -605,6 +606,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + decode_meta.attn_bias, layer._k_scale, layer._v_scale, ) diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 076f151ffcb61..bc651fa0dc326 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -95,6 +95,7 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], k_scale: float, v_scale: float, tp_rank: int = 0, @@ -140,6 +141,7 @@ def forward_decode( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -178,6 +180,7 @@ def forward_decode( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -203,11 +206,13 @@ def forward_prefix( context_lens: torch.Tensor, max_query_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], sliding_window: Optional[int], k_scale: float, v_scale: float, ) -> torch.Tensor: output = torch.empty_like(query) + assert attn_bias is None, "Bias for prefix not yet enabled" context_attention_fwd( query, key, From 85636a12b7b372167295d3e016d48929154baa11 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 19 Dec 2024 10:16:11 +0000 Subject: [PATCH 2/7] add working kernel with padded_max_seq_len as arg Signed-off-by: NickLucche --- csrc/attention/attention_kernels.cuh | 56 +++++++++++++--------------- csrc/attention/paged_attention_v1.cu | 38 ++++++++++--------- csrc/attention/paged_attention_v2.cu | 17 +++++++-- 3 files changed, 59 insertions(+), 52 deletions(-) diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 25de77c324c62..08f9882f65f09 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -104,7 +104,8 @@ __device__ void paged_attention_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] - const float* __restrict__ attn_bias, // [num_seqs, num_heads, max_seq_len] + const float* __restrict__ attn_bias, // [num_seqs, num_heads, max_seq_len] + const int padded_max_seq_len, // Avoid recomputing from seq_lens. 
const int q_stride, const int kv_block_stride, const int kv_head_stride, const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, @@ -154,21 +155,14 @@ __device__ void paged_attention_kernel( const int kv_head_idx = head_idx / num_queries_per_kv; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; - // TODO check if indexing still makes sense - // seq_len indexes on 'max_seq_lens' dim, - // it's like renaming dim you get attn_bias: seq_len x num_kv_heads x seq_len - // TODO each seq can have different len (seq_lens) but only one bias!! - // NOTE (NickLucche) `max_seq_len` bias values for current sequence and current head + + // NOTE (NickLucche) `max_seq_len` (padded) bias values for current sequence + // and current head. const float* attn_bias_vec = attn_bias == nullptr ? nullptr - : attn_bias + seq_idx * num_heads * num_seq_blocks * BLOCK_SIZE + - head_idx * num_seq_blocks * BLOCK_SIZE; - // : attn_bias + seq_idx * num_kv_heads * num_seq_blocks * BLOCK_SIZE + - // const float* attn_bias_vec = attn_bias == nullptr - // ? nullptr - // : attn_bias + seq_idx * num_kv_heads * seq_len + - // kv_head_idx * seq_len; + : attn_bias + seq_idx * num_heads * padded_max_seq_len + + head_idx * padded_max_seq_len; // A vector type to store a part of a key or a query. // The vector size is configured in such a way that the threads in a thread @@ -309,9 +303,7 @@ __device__ void paged_attention_kernel( // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot( q_vecs[thread_group_offset], k_vecs); - // NOTE here each thread adds its own alibi (one per head..) like I am - // sure not the whole group needs to do so Add the ALiBi bias if slopes - // are given. + // Add the ALiBi bias if slopes are given, then add custom bias if given. // TODO mutually exclusive? qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk += (attn_bias_vec != nullptr) ? attn_bias_vec[token_idx] : 0; @@ -532,17 +524,18 @@ __global__ void paged_attention_v1_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] - const float* __restrict__ attn_bias, const int q_stride, - const int kv_block_stride, const int kv_head_stride, const float k_scale, - const float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const float* __restrict__ attn_bias, + const int padded_max_seq_len, // Avoid recomputing from seq_lens. 
+ const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, - max_num_blocks_per_seq, alibi_slopes, attn_bias, q_stride, - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + max_num_blocks_per_seq, alibi_slopes, attn_bias, padded_max_seq_len, + q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } @@ -569,18 +562,19 @@ __global__ void paged_attention_v2_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] - const float* __restrict__ attn_bias, const int q_stride, - const int kv_block_stride, const int kv_head_stride, const float k_scale, - const float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const float* __restrict__ attn_bias, + const int padded_max_seq_len, // Avoid recomputing from seq_lens. + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, attn_bias, - q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, - blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step); + padded_max_seq_len, q_stride, kv_block_stride, kv_head_stride, k_scale, + v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step); } // Grid: (num_heads, num_seqs). diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 13a10221db425..cbe5d1dd6f3f6 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -29,21 +29,21 @@ #define MIN(a, b) ((a) < (b) ? 
(a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), \ - shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ - out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ - scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, attn_bias_ptr, q_stride, kv_block_stride, \ - kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ - blocksparse_vert_stride, blocksparse_block_size, \ - blocksparse_head_sliding_step); +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, attn_bias_ptr, padded_max_seq_len, q_stride, \ + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); // TODO(woosuk): Tune NUM_THREADS. template (attn_bias.value().data_ptr()) : nullptr; - if (attn_bias_ptr){ - TORCH_CHECK(attn_bias.value().dtype() == torch::kFloat32, "Unsupported bias dtype: ", attn_bias.value().dtype()); + if (attn_bias_ptr) { + const torch::Tensor& abias = attn_bias.value(); + TORCH_CHECK(abias.dtype() == torch::kFloat32, + "Unsupported bias dtype: ", abias.dtype()); + TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len, + "Unexpected attn_bias shape: ", abias.sizes()); } T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 80e1d7cb962df..2b25a6afe765f 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -37,9 +37,10 @@ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, \ - attn_bias_ptr, q_stride, kv_block_stride, kv_head_stride, k_scale, \ - v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ - blocksparse_block_size, blocksparse_head_sliding_step); \ + attn_bias_ptr, padded_max_seq_len, q_stride, kv_block_stride, \ + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel \ <<>>( \ @@ -78,7 +79,13 @@ void paged_attention_v2_launcher( const float* attn_bias_ptr = attn_bias ? 
reinterpret_cast(attn_bias.value().data_ptr()) : nullptr; - + if (attn_bias_ptr) { + const torch::Tensor& abias = attn_bias.value(); + TORCH_CHECK(abias.dtype() == torch::kFloat32, + "Unsupported bias dtype: ", abias.dtype()); + TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len, + "Unexpected attn_bias shape: ", abias.sizes()); + } T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); @@ -91,6 +98,8 @@ void paged_attention_v2_launcher( constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = PARTITION_SIZE * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); From 0f39fdc4953b1d312bce994c170b48b6277234d7 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 19 Dec 2024 10:22:50 +0000 Subject: [PATCH 3/7] add attn_bias case to pagedattn tests Signed-off-by: NickLucche --- tests/kernels/test_attention.py | 51 ++++++++++++++------------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 8a3b21e197612..efc7c040c1212 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -18,8 +18,7 @@ FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 # This will change depending on the compute capability. # - 512 as a buffer -# MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 -MAX_SEQ_LEN = 16 +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 # There may not be enough gpu memory due to large NUM_BLOCKS. # Reduce NUM_BLOCKS when it happens. NUM_BLOCKS = 4321 # Arbitrary values for testing @@ -30,7 +29,6 @@ ] if not current_platform.is_rocm() else [torch.half, torch.bfloat16] NUM_GEN_SEQS = [7] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing -# TODO fix different num of heads NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing # This should be sync with get_supported_head_sizes() in @@ -109,9 +107,7 @@ def ref_single_query_cached_kv_attention( 1, 1, -1) bias = alibi_bias if attn_bias is not None: - # TODO test alibi + bias bias = attn_bias[i] if bias is None else bias + attn_bias[i] - # print(f"ATTN BIAS {i}: {attn_bias[i]}") out = ref_masked_attention(q, keys, values, scale, bias) out = out.view(num_query_heads, head_size) output[i].copy_(out, non_blocking=True) @@ -144,10 +140,6 @@ def test_paged_attention( seed: int, device: str, ) -> None: - # num_heads = (2, 2) - # num_seqs = 2 - # head_size = 32 - if ((kv_cache_dtype == "fp8" and head_size % 16) or (version == "rocm" and head_size not in (64, 128))): pytest.skip() @@ -173,16 +165,20 @@ def test_paged_attention( if use_custom_attn_bias: # NOTE (NickLucche) each sequence can have a different bias, # depending on its len, but it *must* be float (f32)! 
- attn_bias_list = [torch.randn(num_query_heads, - 1, - seq_len, - dtype=torch.float) for seq_len in seq_lens] - attn_bias = torch.empty(num_seqs, num_query_heads, 1, max_seq_len, device=device, dtype=torch.float) + attn_bias_list = [ + torch.randn(num_query_heads, 1, seq_len, dtype=torch.float) + for seq_len in seq_lens + ] + attn_bias = torch.empty(num_seqs, + num_query_heads, + 1, + max_seq_len, + device=device, + dtype=torch.float) for i, (seq_len, bias) in enumerate(zip(seq_lens, attn_bias_list)): - # first seq_len entries of the bias for each head/seq + # first `seq_len` entries of the bias matrix for each head/seq attn_bias[i, :, :, :seq_len] = bias - # print("bias shape", attn_bias.shape) # Create the block tables. max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size @@ -208,7 +204,6 @@ def test_paged_attention( # Call the paged attention kernel. output = torch.empty_like(query) - # print("BIAS", attn_bias) if version == "v1": ops.paged_attention_v1( output, @@ -227,18 +222,15 @@ def test_paged_attention( k_scale, v_scale, ) - # print("\nOUT", output) opcheck(torch.ops._C.paged_attention_v1, (output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - attn_bias, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + attn_bias, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) elif version in ("v2", "rocm"): - assert False num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -273,14 +265,14 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v2, - (output, exp_sums, max_logits, tmp_output, query, - key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, - attn_bias, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v2, + (output, exp_sums, max_logits, tmp_output, query, key_cache, + value_cache, num_kv_heads, scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, attn_bias, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) else: ops.paged_attention_rocm( @@ -338,7 +330,6 @@ def test_paged_attention( key_cache, value_cache, block_tables, seq_lens, scale, alibi_slopes, attn_bias_list) - # print("\nREF OUT", ref_output) # NOTE(woosuk): Due to the kernel-level differences in the two # implementations, there is a small numerical difference in the two From a1de4bcd59cf173a85366410d79506c87729ea77 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 19 Dec 2024 10:32:29 +0000 Subject: [PATCH 4/7] format Signed-off-by: NickLucche --- benchmarks/kernels/benchmark_paged_attention.py | 2 +- vllm/attention/backends/xformers.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 120b8ffe9c657..53d1f5803cf20 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -114,7 +114,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_seq_len, alibi_slopes, - None, # TODO add custom bias + None, # TODO add custom bias kv_cache_dtype, k_scale, 
v_scale, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 733287ab51e2b..058960456f831 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -606,6 +606,8 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + # TODO (NickLucche) cross_attn_bias not needed for T5-like + # models, abstract bias selection if needed. decode_meta.attn_bias, layer._k_scale, layer._v_scale, From 7e776b6afaad361b4eef49552b2c1cb46e35f091 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Thu, 19 Dec 2024 10:57:23 +0000 Subject: [PATCH 5/7] format Signed-off-by: NickLucche --- vllm/attention/backends/xformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 058960456f831..15732af7acad8 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -606,7 +606,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, - # TODO (NickLucche) cross_attn_bias not needed for T5-like + # TODO (NickLucche) cross_attn_bias not needed for T5-like # models, abstract bias selection if needed. decode_meta.attn_bias, layer._k_scale, From 17affcffe88dea53aea764155f897a876099bbce Mon Sep 17 00:00:00 2001 From: NickLucche Date: Fri, 27 Dec 2024 17:21:43 +0000 Subject: [PATCH 6/7] enforce last dim of attn bias to be block aligned Signed-off-by: NickLucche --- csrc/attention/paged_attention_v1.cu | 17 ++++++++++------- csrc/attention/paged_attention_v2.cu | 21 ++++++++++++--------- tests/kernels/test_attention.py | 19 +++++++++++-------- 3 files changed, 33 insertions(+), 24 deletions(-) diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index cbe5d1dd6f3f6..0b04b55a4e13a 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -77,12 +77,17 @@ void paged_attention_v1_launcher( const float* attn_bias_ptr = attn_bias ? reinterpret_cast(attn_bias.value().data_ptr()) : nullptr; + const int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; if (attn_bias_ptr) { const torch::Tensor& abias = attn_bias.value(); TORCH_CHECK(abias.dtype() == torch::kFloat32, "Unsupported bias dtype: ", abias.dtype()); - TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len, - "Unexpected attn_bias shape: ", abias.sizes()); + TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len, + "The last dimension of the attention bias must " + "match the block-aligned maximum sequence length (", + padded_max_seq_len, + "). However, the given dimensions are: ", abias.sizes()); } T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); @@ -92,13 +97,11 @@ void paged_attention_v1_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_seq_len = - DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_seq_len * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + const int logits_size = padded_max_seq_len * sizeof(float); + const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! 
- int shared_mem_size = std::max(logits_size, outputs_size); + const int shared_mem_size = std::max(logits_size, outputs_size); dim3 grid(num_heads, num_seqs, 1); dim3 block(NUM_THREADS); diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 2b25a6afe765f..5eeba75d5cf1c 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -79,12 +79,17 @@ void paged_attention_v2_launcher( const float* attn_bias_ptr = attn_bias ? reinterpret_cast(attn_bias.value().data_ptr()) : nullptr; + const int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; if (attn_bias_ptr) { const torch::Tensor& abias = attn_bias.value(); TORCH_CHECK(abias.dtype() == torch::kFloat32, "Unsupported bias dtype: ", abias.dtype()); - TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len, - "Unexpected attn_bias shape: ", abias.sizes()); + TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len, + "The last dimension of the attention bias must " + "match the block-aligned maximum sequence length (", + padded_max_seq_len, + "). However, the given dimensions are: ", abias.sizes()); } T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); @@ -97,18 +102,16 @@ void paged_attention_v2_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); - int padded_max_seq_len = - DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = PARTITION_SIZE * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + const int logits_size = PARTITION_SIZE * sizeof(float); + const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // For paged attention v2 kernel. dim3 grid(num_heads, num_seqs, max_num_partitions); - int shared_mem_size = std::max(logits_size, outputs_size); + const int shared_mem_size = std::max(logits_size, outputs_size); // For paged attention v2 reduce kernel. dim3 reduce_grid(num_heads, num_seqs); - int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + const int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); dim3 block(NUM_THREADS); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index efc7c040c1212..22932cb7b9c94 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -162,26 +162,29 @@ def test_paged_attention( max_seq_len = max(seq_lens) seq_lens = torch.tensor(seq_lens, dtype=torch.int) attn_bias_list = None + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size if use_custom_attn_bias: # NOTE (NickLucche) each sequence can have a different bias, - # depending on its len, but it *must* be float (f32)! + # depending on its len, but it *must* be padded to the block + # aligned max_seq_len and of type float32! 
attn_bias_list = [ torch.randn(num_query_heads, 1, seq_len, dtype=torch.float) for seq_len in seq_lens ] - attn_bias = torch.empty(num_seqs, - num_query_heads, - 1, - max_seq_len, - device=device, - dtype=torch.float) + block_aligned_max_seq_len = max_num_blocks_per_seq * block_size + attn_bias = torch.empty( + num_seqs, + num_query_heads, + 1, + block_aligned_max_seq_len, # padded dim + device=device, + dtype=torch.float) for i, (seq_len, bias) in enumerate(zip(seq_lens, attn_bias_list)): # first `seq_len` entries of the bias matrix for each head/seq attn_bias[i, :, :, :seq_len] = bias # Create the block tables. - max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables_lst: List[List[int]] = [] for _ in range(num_seqs): block_table = [ From ca9562bf65d34c7b4e59027a57e50c65ac9ad3b0 Mon Sep 17 00:00:00 2001 From: NickLucche Date: Wed, 15 Jan 2025 14:28:52 +0000 Subject: [PATCH 7/7] fix blocksparse tests Signed-off-by: NickLucche --- tests/kernels/test_blocksparse_attention.py | 2 ++ vllm/attention/ops/ipex_attn.py | 1 + 2 files changed, 3 insertions(+) diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index fad342d1b5923..1cfba68483338 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -228,6 +228,7 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale, @@ -265,6 +266,7 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + None, kv_cache_dtype, k_scale, v_scale, diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index cbc6c74acf09a..5e4b1c8bc29e2 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -103,6 +103,7 @@ def forward_decode( block_size, max_context_len, alibi_slopes, + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale,
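
Usage sketch: the series threads an optional `attn_bias` tensor from `vllm._custom_ops.paged_attention_v1/v2` (and `PagedAttention.forward_decode`) down to the CUDA kernels, inserted immediately after `alibi_slopes` in every signature; the CPU backend currently rejects a non-None bias. The kernels require the bias to be float32 with its last dimension padded up to the block-aligned maximum sequence length, as checked in paged_attention_v1.cu/v2.cu. The snippet below is an illustrative sketch of how a caller might pack per-sequence biases into that layout, mirroring the construction in tests/kernels/test_attention.py; the helper name and argument names are not part of the patches.

    import torch

    from vllm import _custom_ops as ops


    def build_block_aligned_bias(per_seq_biases, seq_lens, num_query_heads,
                                 block_size, device="cuda"):
        """Pack per-sequence float32 biases of shape
        [num_query_heads, 1, seq_len] into the dense layout the kernels read:
        [num_seqs, num_query_heads, 1, padded_max_seq_len], with the last
        dimension rounded up to a multiple of block_size.
        """
        max_seq_len = int(max(seq_lens))
        padded_max_seq_len = ((max_seq_len + block_size - 1) //
                              block_size) * block_size
        attn_bias = torch.zeros(len(per_seq_biases), num_query_heads, 1,
                                padded_max_seq_len, device=device,
                                dtype=torch.float)
        for i, (seq_len, bias) in enumerate(zip(seq_lens, per_seq_biases)):
            # Positions past seq_len are only padding; the kernel masks them
            # out, but the tensor must still extend to the block-aligned
            # length (zero-filled here).
            attn_bias[i, :, :, :int(seq_len)] = bias
        return attn_bias


    # The decode-side call keeps its existing argument order; attn_bias goes
    # right after alibi_slopes and may be None to preserve current behavior:
    #   ops.paged_attention_v1(output, query, key_cache, value_cache,
    #                          num_kv_heads, scale, block_tables, seq_lens,
    #                          block_size, max_seq_len, alibi_slopes,
    #                          attn_bias, kv_cache_dtype, k_scale, v_scale)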