diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 14eef00b855ac..53d1f5803cf20 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -114,6 +114,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_seq_len, alibi_slopes, + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale, @@ -134,6 +135,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: block_size, max_seq_len, alibi_slopes, + None, kv_cache_dtype, k_scale, v_scale, diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 563e1438f0b01..08f9882f65f09 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -104,6 +104,8 @@ __device__ void paged_attention_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ attn_bias, // [num_seqs, num_heads, max_seq_len] + const int padded_max_seq_len, // Avoid recomputing from seq_lens. const int q_stride, const int kv_block_stride, const int kv_head_stride, const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, @@ -154,6 +156,14 @@ __device__ void paged_attention_kernel( const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + // NOTE (NickLucche) `max_seq_len` (padded) bias values for current sequence + // and current head. + const float* attn_bias_vec = + attn_bias == nullptr + ? nullptr + : attn_bias + seq_idx * num_heads * padded_max_seq_len + + head_idx * padded_max_seq_len; + // A vector type to store a part of a key or a query. // The vector size is configured in such a way that the threads in a thread // group fetch or compute 16 bytes at a time. For example, if the size of a @@ -293,8 +303,10 @@ __device__ void paged_attention_kernel( // This includes a reduction across the threads in the same thread group. float qk = scale * Qk_dot::dot( q_vecs[thread_group_offset], k_vecs); - // Add the ALiBi bias if slopes are given. + // Add the ALiBi bias if slopes are given, then add custom bias if given. + // TODO mutually exclusive? qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; + qk += (attn_bias_vec != nullptr) ? attn_bias_vec[token_idx] : 0; if (thread_group_offset == 0) { // Store the partial reductions to shared memory. @@ -512,6 +524,8 @@ __global__ void paged_attention_v1_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ attn_bias, + const int padded_max_seq_len, // Avoid recomputing from seq_lens. 
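Note (illustrative, not part of the patch): the `attn_bias_vec` arithmetic above treats the bias as one contiguous float32 tensor of shape [num_seqs, num_heads, padded_max_seq_len]. A minimal Python sketch of the equivalent indexing, with made-up sizes:

import torch

num_seqs, num_heads, padded_max_seq_len = 2, 8, 32
attn_bias = torch.randn(num_seqs, num_heads, padded_max_seq_len,
                        dtype=torch.float32)
flat = attn_bias.reshape(-1)  # what the kernel sees through a flat float*

seq_idx, head_idx, token_idx = 1, 3, 17
# Mirrors: attn_bias + seq_idx * num_heads * padded_max_seq_len
#                    + head_idx * padded_max_seq_len, then [token_idx].
offset = (seq_idx * num_heads + head_idx) * padded_max_seq_len
assert flat[offset + token_idx] == attn_bias[seq_idx, head_idx, token_idx]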
const int q_stride, const int kv_block_stride, const int kv_head_stride, const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, @@ -520,9 +534,9 @@ __global__ void paged_attention_v1_kernel( KV_DTYPE, IS_BLOCK_SPARSE>( /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, seq_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, - kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, - blocksparse_vert_stride, blocksparse_block_size, + max_num_blocks_per_seq, alibi_slopes, attn_bias, padded_max_seq_len, + q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step); } @@ -548,6 +562,8 @@ __global__ void paged_attention_v2_kernel( const int* __restrict__ seq_lens, // [num_seqs] const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] + const float* __restrict__ attn_bias, + const int padded_max_seq_len, // Avoid recomputing from seq_lens. const int q_stride, const int kv_block_stride, const int kv_head_stride, const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, @@ -555,10 +571,10 @@ __global__ void paged_attention_v2_kernel( paged_attention_kernel( exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, - blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step); + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, attn_bias, + padded_max_seq_len, q_stride, kv_block_stride, kv_head_stride, k_scale, + v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step); } // Grid: (num_heads, num_seqs). diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 27321148f6dda..0b04b55a4e13a 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -40,10 +40,10 @@ <<>>( \ out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ - alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ - k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ - blocksparse_vert_stride, blocksparse_block_size, \ - blocksparse_head_sliding_step); + alibi_slopes_ptr, attn_bias_ptr, padded_max_seq_len, q_stride, \ + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); // TODO(woosuk): Tune NUM_THREADS. template & alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, float k_scale, float v_scale, + const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); @@ -73,7 +74,21 @@ void paged_attention_v1_launcher( alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - + const float* attn_bias_ptr = + attn_bias ? 
reinterpret_cast(attn_bias.value().data_ptr()) + : nullptr; + const int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + if (attn_bias_ptr) { + const torch::Tensor& abias = attn_bias.value(); + TORCH_CHECK(abias.dtype() == torch::kFloat32, + "Unsupported bias dtype: ", abias.dtype()); + TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len, + "The last dimension of the attention bias must " + "match the block-aligned maximum sequence length (", + padded_max_seq_len, + "). However, the given dimensions are: ", abias.sizes()); + } T* out_ptr = reinterpret_cast(out.data_ptr()); T* query_ptr = reinterpret_cast(query.data_ptr()); CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); @@ -82,13 +97,11 @@ void paged_attention_v1_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_seq_len = - DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; - int logits_size = padded_max_seq_len * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + const int logits_size = padded_max_seq_len * sizeof(float); + const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! - int shared_mem_size = std::max(logits_size, outputs_size); + const int shared_mem_size = std::max(logits_size, outputs_size); dim3 grid(num_heads, num_seqs, 1); dim3 block(NUM_THREADS); @@ -135,8 +148,8 @@ void paged_attention_v1_launcher( paged_attention_v1_launcher( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ - seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ + seq_lens, max_seq_len, alibi_slopes, attn_bias, k_scale, v_scale, \ + tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); #define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ @@ -176,7 +189,8 @@ void paged_attention_v1( torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, - const std::optional& alibi_slopes, + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index a453b2243e48c..5eeba75d5cf1c 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -36,10 +36,11 @@ <<>>( \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ - seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ - blocksparse_local_blocks, blocksparse_vert_stride, \ - blocksparse_block_size, blocksparse_head_sliding_step); \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, \ + attn_bias_ptr, padded_max_seq_len, q_stride, kv_block_stride, \ + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel \ <<>>( \ @@ -54,8 
+55,9 @@ void paged_attention_v2_launcher( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, - const std::optional& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, float k_scale, float v_scale, + const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); @@ -74,7 +76,21 @@ void paged_attention_v2_launcher( alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - + const float* attn_bias_ptr = + attn_bias ? reinterpret_cast(attn_bias.value().data_ptr()) + : nullptr; + const int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + if (attn_bias_ptr) { + const torch::Tensor& abias = attn_bias.value(); + TORCH_CHECK(abias.dtype() == torch::kFloat32, + "Unsupported bias dtype: ", abias.dtype()); + TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len, + "The last dimension of the attention bias must " + "match the block-aligned maximum sequence length (", + padded_max_seq_len, + "). However, the given dimensions are: ", abias.sizes()); + } T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); @@ -86,16 +102,16 @@ void paged_attention_v2_launcher( int* seq_lens_ptr = seq_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); - int logits_size = PARTITION_SIZE * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + const int logits_size = PARTITION_SIZE * sizeof(float); + const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // For paged attention v2 kernel. dim3 grid(num_heads, num_seqs, max_num_partitions); - int shared_mem_size = std::max(logits_size, outputs_size); + const int shared_mem_size = std::max(logits_size, outputs_size); // For paged attention v2 reduce kernel. 
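Aside (illustrative only): both launchers enforce the same contract on the optional bias, namely float32 dtype and a last dimension equal to the block-aligned max_seq_len, i.e. DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE. A hedged Python equivalent of those checks (`validate_attn_bias` is a hypothetical helper, not in the patch):

import torch
from typing import Optional

def validate_attn_bias(attn_bias: Optional[torch.Tensor], max_seq_len: int,
                       block_size: int) -> None:
    # Hypothetical mirror of the TORCH_CHECKs added in paged_attention_v{1,2}.cu.
    if attn_bias is None:
        return
    padded_max_seq_len = (max_seq_len + block_size - 1) // block_size * block_size
    if attn_bias.dtype != torch.float32:
        raise ValueError(f"Unsupported bias dtype: {attn_bias.dtype}")
    if attn_bias.shape[-1] != padded_max_seq_len:
        raise ValueError(
            f"The last dimension of the attention bias must match the "
            f"block-aligned maximum sequence length ({padded_max_seq_len}), "
            f"got {tuple(attn_bias.shape)}")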
dim3 reduce_grid(num_heads, num_seqs); - int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + const int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); dim3 block(NUM_THREADS); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); @@ -142,7 +158,7 @@ void paged_attention_v2_launcher( IS_BLOCK_SPARSE>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ - k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + attn_bias, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step); @@ -187,7 +203,8 @@ void paged_attention_v2( torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, - const std::optional& alibi_slopes, + const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index ef5b14088c63b..eb33c66953a6e 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -459,7 +459,8 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, @@ -467,6 +468,8 @@ void paged_attention_v1( TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); + TORCH_CHECK(!attn_bias.has_value(), + "CPU backend does not support custom attention bias."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v1_impl) @@ -781,7 +784,8 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, @@ -789,6 +793,8 @@ void paged_attention_v2( TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); + TORCH_CHECK(!attn_bias.has_value(), + "CPU backend does not support custom attention bias."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", [&] { CPU_KERNEL_GUARD_IN(paged_attention_v2_impl) diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 74e4d8189d403..3cfa289848e21 100644 --- 
a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -24,12 +24,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Attention ops // Compute the attention between an input query and the cached keys/values // using PagedAttention. + // TODO attn_bias on cpu ops.def( "paged_attention_v1(" " Tensor! out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," @@ -43,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," diff --git a/csrc/ops.h b/csrc/ops.h index 5a194a0dd3654..f2ad92074f446 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -33,7 +33,8 @@ void paged_attention_v1( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, @@ -44,7 +45,8 @@ void paged_attention_v2( torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, - int64_t max_seq_len, const std::optional& alibi_slopes, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const c10::optional& attn_bias, const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index fb53d122487d3..5ab404523494f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -29,7 +29,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? attn_bias," " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," @@ -43,7 +43,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," - " int max_seq_len, Tensor? alibi_slopes," + " int max_seq_len, Tensor? alibi_slopes, Tensor? 
attn_bias," " str kv_cache_dtype, float k_scale, float v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 124d5d297a574..22932cb7b9c94 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -37,6 +37,7 @@ BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] +USE_CUSTOM_ATTN_BIAS = [False, True] KV_CACHE_DTYPE = ["auto", "fp8"] SEEDS = [0] CUDA_DEVICES = [ @@ -60,16 +61,11 @@ def ref_masked_attention( def ref_single_query_cached_kv_attention( - output: torch.Tensor, - query: torch.Tensor, - num_queries_per_kv: int, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - seq_lens: torch.Tensor, - scale: float, - alibi_slopes: Optional[torch.Tensor], -) -> None: + output: torch.Tensor, query: torch.Tensor, num_queries_per_kv: int, + key_cache: torch.Tensor, value_cache: torch.Tensor, + block_tables: torch.Tensor, seq_lens: torch.Tensor, scale: float, + alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[List[torch.Tensor]]) -> None: num_query_heads = query.shape[1] num_kv_heads = value_cache.shape[1] head_size = value_cache.shape[2] @@ -102,15 +98,17 @@ def ref_single_query_cached_kv_attention( keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) - alibi_bias = None + bias = None if alibi_slopes is not None: # Create the ALiBi bias used in the paged attention kernel. position_ids = torch.arange(seq_len).int() alibi_bias = (position_ids - seq_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( 1, 1, -1) - - out = ref_masked_attention(q, keys, values, scale, alibi_bias) + bias = alibi_bias + if attn_bias is not None: + bias = attn_bias[i] if bias is None else bias + attn_bias[i] + out = ref_masked_attention(q, keys, values, scale, bias) out = out.view(num_query_heads, head_size) output[i].copy_(out, non_blocking=True) @@ -122,6 +120,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("use_custom_attn_bias", USE_CUSTOM_ATTN_BIAS) @pytest.mark.parametrize("block_size", BLOCK_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @@ -134,6 +133,7 @@ def test_paged_attention( num_heads: Tuple[int, int], head_size: int, use_alibi: bool, + use_custom_attn_bias: bool, block_size: int, dtype: torch.dtype, kv_cache_dtype: str, @@ -153,7 +153,7 @@ def test_paged_attention( assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads - alibi_slopes = None + alibi_slopes, attn_bias = None, None if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) @@ -161,9 +161,30 @@ def test_paged_attention( seq_lens[-1] = MAX_SEQ_LEN max_seq_len = max(seq_lens) seq_lens = torch.tensor(seq_lens, dtype=torch.int) + attn_bias_list = None + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + if use_custom_attn_bias: + # NOTE (NickLucche) each sequence can have a different bias, + # depending on its len, but it *must* be padded to the block + # aligned max_seq_len and of type float32! 
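Aside (illustrative only): as in the reference path above, the ALiBi bias and the custom bias are simply added to the scaled QK logits before the softmax, matching the two `qk +=` lines in the kernel. A small sketch of that semantics, with made-up shapes:

import torch

num_heads, head_size, seq_len, scale = 4, 64, 10, 0.125
q = torch.randn(num_heads, 1, head_size)          # one decode query per head
keys = torch.randn(num_heads, seq_len, head_size)

alibi_bias = torch.randn(num_heads, 1, seq_len)   # derived from alibi_slopes, if any
custom_bias = torch.randn(num_heads, 1, seq_len)  # first seq_len entries of the padded bias

logits = scale * q @ keys.transpose(-2, -1) + alibi_bias + custom_bias
probs = torch.softmax(logits, dim=-1)             # [num_heads, 1, seq_len]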
+ attn_bias_list = [ + torch.randn(num_query_heads, 1, seq_len, dtype=torch.float) + for seq_len in seq_lens + ] + block_aligned_max_seq_len = max_num_blocks_per_seq * block_size + attn_bias = torch.empty( + num_seqs, + num_query_heads, + 1, + block_aligned_max_seq_len, # padded dim + device=device, + dtype=torch.float) + + for i, (seq_len, bias) in enumerate(zip(seq_lens, attn_bias_list)): + # first `seq_len` entries of the bias matrix for each head/seq + attn_bias[i, :, :, :seq_len] = bias # Create the block tables. - max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size block_tables_lst: List[List[int]] = [] for _ in range(num_seqs): block_table = [ @@ -199,6 +220,7 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -207,7 +229,7 @@ def test_paged_attention( opcheck(torch.ops._C.paged_attention_v1, (output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + attn_bias, kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) @@ -240,18 +262,20 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, ) - opcheck(torch.ops._C.paged_attention_v2, - (output, exp_sums, max_logits, tmp_output, query, - key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v2, + (output, exp_sums, max_logits, tmp_output, query, key_cache, + value_cache, num_kv_heads, scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, attn_bias, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) else: ops.paged_attention_rocm( @@ -305,17 +329,10 @@ def test_paged_attention( value_cache = dequantized_value_cache ref_output = torch.empty_like(query) - ref_single_query_cached_kv_attention( - ref_output, - query, - num_queries_per_kv, - key_cache, - value_cache, - block_tables, - seq_lens, - scale, - alibi_slopes, - ) + ref_single_query_cached_kv_attention(ref_output, query, num_queries_per_kv, + key_cache, value_cache, block_tables, + seq_lens, scale, alibi_slopes, + attn_bias_list) # NOTE(woosuk): Due to the kernel-level differences in the two # implementations, there is a small numerical difference in the two diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index fad342d1b5923..1cfba68483338 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -228,6 +228,7 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale, @@ -265,6 +266,7 @@ def test_paged_attention( block_size, max_seq_len, alibi_slopes, + None, kv_cache_dtype, k_scale, v_scale, diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index d04cbbc0a9eed..cae9c5e31fa68 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -47,6 +47,7 @@ def paged_attention_v1( block_size: int, max_seq_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], kv_cache_dtype: str, k_scale: float, v_scale: float, @@ -58,8 +59,8 @@ def 
paged_attention_v1( ) -> None: torch.ops._C.paged_attention_v1( out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, - k_scale, v_scale, tp_rank, blocksparse_local_blocks, + seq_lens, block_size, max_seq_len, alibi_slopes, attn_bias, + kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) @@ -79,6 +80,7 @@ def paged_attention_v2( block_size: int, max_seq_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], kv_cache_dtype: str, k_scale: float, v_scale: float, @@ -91,7 +93,7 @@ def paged_attention_v2( torch.ops._C.paged_attention_v2( out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, - alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, + alibi_slopes, attn_bias, kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, blocksparse_head_sliding_step) diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 9089db1126c94..0700f4195a634 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -439,6 +439,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + None, # TODO support attn_bias layer._k_scale, layer._v_scale, tp_rank=self.tp_rank, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e9f2808ff1674..477c52141020b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -628,6 +628,7 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + None, # TODO support attn_bias layer._k_scale, layer._v_scale, ) diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 38e27434dab2c..15732af7acad8 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -577,6 +577,7 @@ def forward( prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, + prefill_meta.attn_bias, self.sliding_window, layer._k_scale, layer._v_scale, @@ -605,6 +606,9 @@ def forward( self.num_kv_heads, self.scale, self.alibi_slopes, + # TODO (NickLucche) cross_attn_bias not needed for T5-like + # models, abstract bias selection if needed. 
+ decode_meta.attn_bias, layer._k_scale, layer._v_scale, ) diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index cbc6c74acf09a..5e4b1c8bc29e2 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -103,6 +103,7 @@ def forward_decode( block_size, max_context_len, alibi_slopes, + None, # TODO add custom bias kv_cache_dtype, k_scale, v_scale, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 076f151ffcb61..bc651fa0dc326 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -95,6 +95,7 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], k_scale: float, v_scale: float, tp_rank: int = 0, @@ -140,6 +141,7 @@ def forward_decode( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -178,6 +180,7 @@ def forward_decode( block_size, max_seq_len, alibi_slopes, + attn_bias, kv_cache_dtype, k_scale, v_scale, @@ -203,11 +206,13 @@ def forward_prefix( context_lens: torch.Tensor, max_query_len: int, alibi_slopes: Optional[torch.Tensor], + attn_bias: Optional[torch.Tensor], sliding_window: Optional[int], k_scale: float, v_scale: float, ) -> torch.Tensor: output = torch.empty_like(query) + assert attn_bias is None, "Bias for prefix not yet enabled" context_attention_fwd( query, key,
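End-to-end note (a sketch under assumptions, not part of the patch): a decode-time caller provides one float32 bias row per (sequence, head), padded on the last dimension to the block-aligned maximum sequence length, and passes it as the new `attn_bias` argument right after `alibi_slopes` (or None to disable it). Mirroring the test setup above, with made-up sizes:

import torch

block_size, num_seqs, num_heads = 16, 3, 8
seq_lens = [5, 33, 17]                  # illustrative decode batch
max_seq_len = max(seq_lens)
padded_max_seq_len = (max_seq_len + block_size - 1) // block_size * block_size

attn_bias = torch.zeros(num_seqs, num_heads, padded_max_seq_len,
                        dtype=torch.float32)
for i, seq_len in enumerate(seq_lens):
    # Fill the first seq_len positions per sequence; the remainder is padding
    # up to the block-aligned length, as in the test above.
    attn_bias[i, :, :seq_len] = torch.randn(num_heads, seq_len)

assert attn_bias.dtype == torch.float32
assert attn_bias.shape[-1] == padded_max_seq_len
# attn_bias (or None) is then forwarded through PagedAttention.forward_decode
# into ops.paged_attention_v1/v2 alongside alibi_slopes.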