enforce last dim of attn bias to be block aligned
Signed-off-by: NickLucche <[email protected]>
NickLucche committed Jan 9, 2025
1 parent ac6bf63 commit 5c47f43
Showing 3 changed files with 33 additions and 24 deletions.
17 changes: 10 additions & 7 deletions csrc/attention/paged_attention_v1.cu
@@ -77,12 +77,17 @@ void paged_attention_v1_launcher(
   const float* attn_bias_ptr =
       attn_bias ? reinterpret_cast<const float*>(attn_bias.value().data_ptr())
                 : nullptr;
+  const int padded_max_seq_len =
+      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
   if (attn_bias_ptr) {
     const torch::Tensor& abias = attn_bias.value();
     TORCH_CHECK(abias.dtype() == torch::kFloat32,
                 "Unsupported bias dtype: ", abias.dtype());
-    TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len,
-                "Unexpected attn_bias shape: ", abias.sizes());
+    TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len,
+                "The last dimension of the attention bias must "
+                "match the block-aligned maximum sequence length (",
+                padded_max_seq_len,
+                "). However, the given dimensions are: ", abias.sizes());
   }
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
@@ -92,13 +97,11 @@
   int* seq_lens_ptr = seq_lens.data_ptr<int>();
 
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int padded_max_seq_len =
-      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
-  int logits_size = padded_max_seq_len * sizeof(float);
-  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  const int logits_size = padded_max_seq_len * sizeof(float);
+  const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
   // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len
   // Keep that in sync with the logic here!
-  int shared_mem_size = std::max(logits_size, outputs_size);
+  const int shared_mem_size = std::max(logits_size, outputs_size);
 
   dim3 grid(num_heads, num_seqs, 1);
   dim3 block(NUM_THREADS);
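The check added above requires the bias' last dimension to equal the block-aligned maximum sequence length rather than max_seq_len itself. A minimal Python sketch of that rounding (not part of the commit; block_aligned_len is a hypothetical helper name):

# Hypothetical helper mirroring DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE,
# i.e. the value the new TORCH_CHECK compares the bias' last dimension against.
def block_aligned_len(max_seq_len: int, block_size: int) -> int:
    return ((max_seq_len + block_size - 1) // block_size) * block_size

assert block_aligned_len(17, 16) == 32  # 17 tokens round up to two 16-token blocks
assert block_aligned_len(32, 16) == 32  # already-aligned lengths are unchanged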
21 changes: 12 additions & 9 deletions csrc/attention/paged_attention_v2.cu
@@ -79,12 +79,17 @@ void paged_attention_v2_launcher(
   const float* attn_bias_ptr =
       attn_bias ? reinterpret_cast<const float*>(attn_bias.value().data_ptr())
                 : nullptr;
+  const int padded_max_seq_len =
+      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
   if (attn_bias_ptr) {
     const torch::Tensor& abias = attn_bias.value();
     TORCH_CHECK(abias.dtype() == torch::kFloat32,
                 "Unsupported bias dtype: ", abias.dtype());
-    TORCH_CHECK(abias.size(abias.dim() - 1) == max_seq_len,
-                "Unexpected attn_bias shape: ", abias.sizes());
+    TORCH_CHECK(abias.size(abias.dim() - 1) == padded_max_seq_len,
+                "The last dimension of the attention bias must "
+                "match the block-aligned maximum sequence length (",
+                padded_max_seq_len,
+                "). However, the given dimensions are: ", abias.sizes());
   }
   T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
   float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
@@ -97,18 +102,16 @@
   int* seq_lens_ptr = seq_lens.data_ptr<int>();
 
   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
-  int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
-  int padded_max_seq_len =
-      DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE;
-  int logits_size = PARTITION_SIZE * sizeof(float);
-  int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
+  const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
+  const int logits_size = PARTITION_SIZE * sizeof(float);
+  const int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float);
 
   // For paged attention v2 kernel.
   dim3 grid(num_heads, num_seqs, max_num_partitions);
-  int shared_mem_size = std::max(logits_size, outputs_size);
+  const int shared_mem_size = std::max(logits_size, outputs_size);
   // For paged attention v2 reduce kernel.
   dim3 reduce_grid(num_heads, num_seqs);
-  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
+  const int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);
 
   dim3 block(NUM_THREADS);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
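For context, the shared-memory sizing the v2 launcher performs above can be traced with a small sketch; the constants below (partition size, thread count, head size) are assumed example values, not taken from the commit:

# Illustrative sketch of the v2 launcher's shared-memory arithmetic.
# All constants here are assumed example values.
def divide_round_up(a: int, b: int) -> int:
    return (a + b - 1) // b

max_seq_len = 1024
partition_size = 512       # stand-in for PARTITION_SIZE (assumed)
num_warps = 128 // 32      # NUM_THREADS / WARP_SIZE, assuming 128 threads
head_size = 128            # assumed head size
float_bytes = 4            # sizeof(float)

max_num_partitions = divide_round_up(max_seq_len, partition_size)  # 2
logits_size = partition_size * float_bytes                         # 2048
outputs_size = (num_warps // 2) * head_size * float_bytes          # 1024
shared_mem_size = max(logits_size, outputs_size)                   # 2048
reduce_shared_mem_size = 2 * max_num_partitions * float_bytes      # 16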
19 changes: 11 additions & 8 deletions tests/kernels/test_attention.py
@@ -162,26 +162,29 @@ def test_paged_attention(
     max_seq_len = max(seq_lens)
     seq_lens = torch.tensor(seq_lens, dtype=torch.int)
     attn_bias_list = None
+    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
     if use_custom_attn_bias:
         # NOTE (NickLucche) each sequence can have a different bias,
-        # depending on its len, but it *must* be float (f32)!
+        # depending on its len, but it *must* be padded to the block
+        # aligned max_seq_len and of type float32!
         attn_bias_list = [
             torch.randn(num_query_heads, 1, seq_len, dtype=torch.float)
             for seq_len in seq_lens
         ]
-        attn_bias = torch.empty(num_seqs,
-                                num_query_heads,
-                                1,
-                                max_seq_len,
-                                device=device,
-                                dtype=torch.float)
+        block_aligned_max_seq_len = max_num_blocks_per_seq * block_size
+        attn_bias = torch.empty(
+            num_seqs,
+            num_query_heads,
+            1,
+            block_aligned_max_seq_len,  # padded dim
+            device=device,
+            dtype=torch.float)
 
         for i, (seq_len, bias) in enumerate(zip(seq_lens, attn_bias_list)):
             # first `seq_len` entries of the bias matrix for each head/seq
             attn_bias[i, :, :, :seq_len] = bias
 
     # Create the block tables.
-    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
     block_tables_lst: List[List[int]] = []
     for _ in range(num_seqs):
         block_table = [
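A caller-side way to satisfy the new layout, sketched under assumptions: build_block_aligned_bias is a hypothetical helper (not part of vLLM) that mirrors the test above, except that it zero-fills the padded tail where the test leaves it uninitialized via torch.empty.

from typing import List

import torch


def build_block_aligned_bias(bias_list: List[torch.Tensor],
                             seq_lens: List[int],
                             block_size: int,
                             device: str = "cuda") -> torch.Tensor:
    """Pack per-sequence biases of shape (num_heads, 1, seq_len) into a
    float32 tensor of shape (num_seqs, num_heads, 1, padded_max_seq_len),
    with the last dim rounded up to a multiple of block_size."""
    max_seq_len = max(seq_lens)
    padded_len = ((max_seq_len + block_size - 1) // block_size) * block_size
    num_heads = bias_list[0].shape[0]
    # Zero-fill the padded tail; only the first seq_len entries per sequence
    # carry the actual bias values.
    out = torch.zeros(len(bias_list), num_heads, 1, padded_len,
                      dtype=torch.float32, device=device)
    for i, (bias, seq_len) in enumerate(zip(bias_list, seq_lens)):
        out[i, :, :, :seq_len] = bias
    return out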
