From 5c47f43fc80e78cdd6e91a8ed3449ec7eac240e5 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Fri, 27 Dec 2024 17:21:43 +0000
Subject: [PATCH] enforce last dim of attn bias to be block aligned

Signed-off-by: NickLucche
---
 csrc/attention/paged_attention_v1.cu | 17 ++++++++++-------
 csrc/attention/paged_attention_v2.cu | 21 ++++++++++++---------
 tests/kernels/test_attention.py      | 19 +++++++++++--------
 3 files changed, 33 insertions(+), 24 deletions(-)

diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu
index cbe5d1dd6f3f6..0b04b55a4e13a 100644
--- a/csrc/attention/paged_attention_v1.cu
+++ b/csrc/attention/paged_attention_v1.cu
@@ -77,12 +77,17 @@ void paged_attention_v1_launcher(
   const float* attn_bias_ptr =
       attn_bias ? reinterpret_cast<const float*>(attn_bias.value().data_ptr())
                 : nullptr;
+  const int padded_max_seq_len =
+      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
   if (attn_bias_ptr) {
     const torch::Tensor& abias = attn_bias.value();
     TORCH_CHECK(abias.dtype() == torch::kFloat32,
                 "Unsupported bias dtype: ", abias.dtype());
-    TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len,
-                "Unexpected attn_bias shape: ", abias.sizes());
+    TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len,
+                "The last dimension of the attention bias must "
+                "match the block-aligned maximum sequence length (",
+                padded_max_seq_len,
+                "). However, the given dimensions are: ", abias.sizes());
   }
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
@@ -92,13 +97,11 @@ void paged_attention_v1_launcher(
   int* seq_lens_ptr = seq_lens.data_ptr<int>();
 
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_seq_len =
-      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
-  int logits_size = padded_max_seq_len * sizeof(float);
-  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  const int logits_size = padded_max_seq_len * sizeof(float);
+  const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
   // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
   // Keep that in sync with the logic here!
-  int shared_mem_size = std::max(logits_size, outputs_size);
+  const int shared_mem_size = std::max(logits_size, outputs_size);
 
   dim3 grid(num_heads, num_seqs, 1);
   dim3 block(NUM_THREADS);
diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu
index 2b25a6afe765f..5eeba75d5cf1c 100644
--- a/csrc/attention/paged_attention_v2.cu
+++ b/csrc/attention/paged_attention_v2.cu
@@ -79,12 +79,17 @@ void paged_attention_v2_launcher(
   const float* attn_bias_ptr =
       attn_bias ? reinterpret_cast<const float*>(attn_bias.value().data_ptr())
                 : nullptr;
+  const int padded_max_seq_len =
+      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
   if (attn_bias_ptr) {
     const torch::Tensor& abias = attn_bias.value();
     TORCH_CHECK(abias.dtype() == torch::kFloat32,
                 "Unsupported bias dtype: ", abias.dtype());
-    TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len,
-                "Unexpected attn_bias shape: ", abias.sizes());
+    TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len,
+                "The last dimension of the attention bias must "
+                "match the block-aligned maximum sequence length (",
+                padded_max_seq_len,
+                "). However, the given dimensions are: ", abias.sizes());
   }
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
@@ -97,18 +102,16 @@ void paged_attention_v2_launcher(
   int* seq_lens_ptr = seq_lens.data_ptr<int>();
 
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
-  int padded_max_seq_len =
-      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
-  int logits_size = PARTITION_SIZE * sizeof(float);
-  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
+  const int logits_size = PARTITION_SIZE * sizeof(float);
+  const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
 
   // For paged attention v2 kernel.
   dim3 grid(num_heads, num_seqs, max_num_partitions);
-  int shared_mem_size = std::max(logits_size, outputs_size);
+  const int shared_mem_size = std::max(logits_size, outputs_size);
   // For paged attention v2 reduce kernel.
   dim3 reduce_grid(num_heads, num_seqs);
-  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
+  const int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
 
   dim3 block(NUM_THREADS);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py
index 0fb374266b628..b9cfe6437183f 100644
--- a/tests/kernels/test_attention.py
+++ b/tests/kernels/test_attention.py
@@ -162,26 +162,29 @@ def test_paged_attention(
     max_seq_len = max(seq_lens)
     seq_lens = torch.tensor(seq_lens, dtype=torch.int)
     attn_bias_list = None
+    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
     if use_custom_attn_bias:
         # NOTE (NickLucche) each sequence can have a different bias,
-        # depending on its len, but it *must* be float (f32)!
+        # depending on its len, but it *must* be padded to the block
+        # aligned max_seq_len and of type float32!
         attn_bias_list = [
             torch.randn(num_query_heads, 1, seq_len, dtype=torch.float)
             for seq_len in seq_lens
         ]
-        attn_bias = torch.empty(num_seqs,
-                                num_query_heads,
-                                1,
-                                max_seq_len,
-                                device=device,
-                                dtype=torch.float)
+        block_aligned_max_seq_len = max_num_blocks_per_seq * block_size
+        attn_bias = torch.empty(
+            num_seqs,
+            num_query_heads,
+            1,
+            block_aligned_max_seq_len,  # padded dim
+            device=device,
+            dtype=torch.float)
         for i, (seq_len, bias) in enumerate(zip(seq_lens, attn_bias_list)):
             # first `seq_len` entries of the bias matrix for each head/seq
            attn_bias[i, :, :, :seq_len] = bias

     # Create the block tables.
-    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
     block_tables_lst: List[List[int]] = []
     for _ in range(num_seqs):
         block_table = [
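A minimal sketch (illustrative only, not part of the patch) of how a caller can build a bias tensor that satisfies the new check: the last dimension has to be padded up to the block-aligned maximum sequence length, mirroring DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE in the launchers above. The helper name and the zero-filled padding below are assumptions for illustration, not something defined by this patch.

# Illustrative helper (assumed, not part of the patch): pads per-sequence
# float32 biases to the block-aligned max_seq_len that the
# paged_attention_v1/v2 launchers expect after this change.
from typing import List

import torch


def build_block_aligned_bias(seq_lens: List[int],
                             num_query_heads: int,
                             block_size: int,
                             device: str = "cuda") -> torch.Tensor:
    max_seq_len = max(seq_lens)
    # Same rounding as DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE.
    padded_max_seq_len = ((max_seq_len + block_size - 1) //
                          block_size) * block_size
    # Bias must be float32; entries past each seq_len are left at zero.
    attn_bias = torch.zeros(len(seq_lens),
                            num_query_heads,
                            1,
                            padded_max_seq_len,
                            device=device,
                            dtype=torch.float)
    for i, seq_len in enumerate(seq_lens):
        # Fill only the first `seq_len` entries for each head/sequence.
        attn_bias[i, :, :, :seq_len] = torch.randn(num_query_heads,
                                                   1,
                                                   seq_len,
                                                   dtype=torch.float,
                                                   device=device)
    return attn_bias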