[Do not merge] vllm layout varlen #106

Draft · wants to merge 17 commits into base: ck_improve_v0.1.3
2 changes: 1 addition & 1 deletion csrc/composable_kernel
Submodule composable_kernel updated 191 files
51 changes: 44 additions & 7 deletions csrc/flash_attn_ck/flash_common.cpp
@@ -5,28 +5,65 @@
 #include "flash_common.hpp"
 
 namespace flash {
-int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
+int override_num_splits_if_necessary(int batch,
+                                     int nhead,
+                                     int max_seqlen_q,
+                                     int hdim_q,
+                                     int hdim_v,
+                                     float p_drop,
+                                     bool is_prefill,
+                                     int num_splits)
 {
     int device;
     auto status = hipGetDevice(&device);
     if(status != hipSuccess)
     {
         return num_splits;
     }
 
     hipDeviceProp_t props{};
     status = hipGetDeviceProperties(&props, device);
     if(status != hipSuccess)
     {
         return num_splits;
     }
 
     // TODO - tile size should match the TileFmhaShape, hardcode for now
-    const int kM0 = 128;
-    const int kN1 = hdim_v;
+    const int kM0 = [&] {
+        // get kM0 for prefill phase
+        if(is_prefill)
+        {
+            return 128;
+        }
+
+        // get kM0 for decode phase
+        /// TODO: take dtype=fp8/bf8 into consideration
+        const std::map<int, int> hdim_to_m0 = {
+            {32, 32},
+            {64, 64},
+            // {96, 64},
+            {128, 64},
+            {256, 64},
+        };
+
+        for(auto [hdim, m0] : hdim_to_m0)
+        {
+            if(hdim_q <= hdim && hdim_v <= hdim)
+            {
+                return m0;
+            }
+        }
+
+        return 64; // meet unsupported hdim_q/hdim_v
+    }();
+    // const int kN1 = hdim_v;
 
     const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
-    const int num_n_blocks = (hdim_v + kN1 - 1) / kN1;
+    // const int num_n_blocks = (hdim_v + kN1 - 1) / kN1; // always 1
 
     if(num_splits < 1 && p_drop == 0.0f)
-        return num_splits_heuristic_ck(
-            batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
+    {
+        return num_splits_heuristic_ck(batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 8);
+    }
 
     return num_splits;
 }
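Note on the kM0 change above: prefill keeps the 128-row tile, while decode picks a smaller tile from the hdim_to_m0 map, so the decode path produces very few m-blocks and relies on the split heuristic below to fill the GPU. A standalone sketch (not part of this PR; the batch, head, and CU counts are made-up example values):

// Sketch only: mirrors the decode-phase kM0 lookup from the diff above.
#include <cstdio>
#include <map>

int pick_m0(int hdim_q, int hdim_v, bool is_prefill)
{
    if(is_prefill)
        return 128; // prefill tile size, as in the PR

    // decode phase: same table as the PR's hdim_to_m0 map
    const std::map<int, int> hdim_to_m0 = {{32, 32}, {64, 64}, {128, 64}, {256, 64}};
    for(auto [hdim, m0] : hdim_to_m0)
        if(hdim_q <= hdim && hdim_v <= hdim)
            return m0;
    return 64; // unsupported head dims fall back to 64
}

int main()
{
    // hypothetical decode-like workload: batch=2, nhead=8, max_seqlen_q=1, hdim_q=hdim_v=128
    const int kM0          = pick_m0(128, 128, /*is_prefill=*/false); // -> 64
    const int num_m_blocks = (1 + kM0 - 1) / kM0;                     // -> 1
    // only 16 work items, far fewer than the CUs on a typical GPU, so splitting helps
    std::printf("kM0=%d, work items=%d\n", kM0, 2 * 8 * num_m_blocks);
    return 0;
}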
63 changes: 35 additions & 28 deletions csrc/flash_attn_ck/flash_common.hpp
@@ -35,42 +35,49 @@ inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* r
     }
 }
 
-inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
+inline int num_splits_heuristic_ck(int batch_nhead_mblocks, int num_SMs, int max_splits)
+{
     // If we have enough to almost fill the SMs, then just use 1 split
-    if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
-    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
+    if(batch_nhead_mblocks >= 0.8f * num_SMs)
+    {
+        return 1;
+    }
+
+    max_splits = std::min({max_splits, num_SMs});
+
+    constexpr std::array<int, 5> num_splits_array = {1, 2, 4, 8, 16};
+
     float max_efficiency = 0.f;
-    std::vector<float> efficiency;
-    efficiency.reserve(max_splits);
-    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
-    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
-    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
-    // (i.e. it's 11 splits anyway).
-    // So we check if the number of blocks per split is the same as the previous num_splits.
-    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
-        return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
-    };
-    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
-        if (!is_split_eligible(num_splits)) {
-            efficiency.push_back(0.f);
-        } else {
-            float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
-            float eff = n_waves / ceil(n_waves);
-            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
-            if (eff > max_efficiency) { max_efficiency = eff; }
-            efficiency.push_back(eff);
-        }
-    }
-    for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
-        if (!is_split_eligible(num_splits)) { continue; }
-        if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
-            // printf("num_splits chosen = %d\n", num_splits);
-            return num_splits;
-        }
-    }
+    std::array<float, num_splits_array.size()> efficiency;
+
+    for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx)
+    {
+        float n_blocks = float(batch_nhead_mblocks * num_splits_array[idx]) / num_SMs;
+        float eff = n_blocks / std::ceil(n_blocks);
+
+        if(eff > max_efficiency)
+        {
+            max_efficiency = eff;
+        }
+        efficiency[idx] = eff;
+    }
+    for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits; ++idx)
+    {
+        if(efficiency[idx] >= 0.85 * max_efficiency)
+        {
+            return num_splits_array[idx];
+        }
+    }
     return 1;
 }
 
-int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits);
+int override_num_splits_if_necessary(int batch,
+                                     int nhead,
+                                     int max_seqlen_q,
+                                     int hdim_q,
+                                     int hdim_v,
+                                     float p_drop,
+                                     bool is_prefill,
+                                     int num_splits);
 
 } // namespace flash
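To see how the rewritten heuristic behaves, here is a standalone trace (not part of this PR) with made-up numbers: 16 batch*nhead*m-block work items, a caller-supplied num_SMs of multiProcessorCount * 2 = 160, and max_splits = 8 as at the new call site:

// Sketch only: replays the efficiency loop from num_splits_heuristic_ck above.
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdio>

int main()
{
    const int work = 16, num_SMs = 160, max_splits = 8;
    const std::array<int, 5> candidates = {1, 2, 4, 8, 16};

    float best = 0.f;
    for(int s : candidates)
    {
        if(s > max_splits)
            break;
        float waves = float(work * s) / num_SMs; // 0.1, 0.2, 0.4, 0.8
        float eff   = waves / std::ceil(waves);
        best        = std::max(best, eff);       // ends at 0.8 for s = 8
        std::printf("splits=%2d eff=%.2f\n", s, eff);
    }
    // the first candidate with eff >= 0.85 * best (= 0.68) is 8, so 8 splits are chosen
    return 0;
}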
5 changes: 4 additions & 1 deletion csrc/flash_attn_ck/mha_fwd_kvcache.cpp
@@ -287,6 +287,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size
                bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
                int num_splits)
 {
+    TORCH_CHECK(false, "vllm layout does not support mha_fwd_kvcache for now");
+
     auto q_dtype = q.dtype();
     TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                 "FlashAttention only support fp16 and bf16 data type");
@@ -471,7 +473,8 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size
         TORCH_CHECK(cache_batch_idx.scalar_type() == torch::kInt32, "cache_batch_idx must have dtype int32");
     }
 
-    num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, seqlen_q, head_size_8x, 0, num_splits);
+    num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, seqlen_q, head_size_8x, head_size_8x,
+                                                         /*p_drop=*/0, /*is_prefill=*/false, num_splits);
     TORCH_CHECK(num_splits > 0, "num_splits should greater than 0");
     TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");
 
31 changes: 17 additions & 14 deletions csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -36,7 +36,7 @@ fmha_fwd_splitkv_traits get_ck_fmha_varlen_fwd_splitkv_traits(const mask_info &m
                head_size,
                dtype,
                true, // is_group_mode
-               true, // is_v_rowmajor
+               false, // is_v_rowmajor
                mask.type,
                enable_alibi ? bias_enum::alibi : bias_enum::no_bias,
                has_lse,
@@ -183,8 +183,8 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
                at::Tensor out_acc)
 {
     // q: (total_q, nheads, d)
-    // k: (num_blocks, page_block_size, num_heads_k, d)
-    // v: (num_blocks, page_block_size, num_heads_k, d)
+    // k: (num_blocks, num_heads_k, d / 8, page_block_size, 8)
+    // v: (num_blocks, num_heads_k, d, page_block_size)
     // o: (total_q, nheads, d)
 
     // alibi_slopes:(batch_size, nheads) or (nhead)
@@ -241,12 +241,12 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
     args.nhead_stride_q = q.stride(1);
 
     args.batch_stride_k = k.stride(0);
-    args.stride_k = k.stride(1);
-    args.nhead_stride_k = k.stride(2);
+    args.nhead_stride_k = k.stride(1);
+    args.stride_k = k.stride(2);
 
     args.batch_stride_v = v.stride(0);
-    args.stride_v = v.stride(1);
-    args.nhead_stride_v = v.stride(2);
+    args.nhead_stride_v = v.stride(1);
+    args.stride_v = v.stride(2);
 
     args.batch_stride_o = 0;
     args.stride_o = out.stride(0);
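The reordered stride assignments above follow from the new cache layouts: for a contiguous (num_blocks, num_heads_k, d / 8, page_block_size, 8) K cache, dim 1 is now the head dimension and dim 2 the d/8 chunk, so nhead_stride_k and stride_k swap sources (likewise for V). A standalone sketch (not part of this PR; sizes are made-up) printing the contiguous strides those assignments read:

// Sketch only: contiguous strides of the vllm-style K/V cache layouts assumed by this diff.
#include <array>
#include <cstddef>
#include <cstdio>

template <std::size_t N>
std::array<long, N> contiguous_strides(const std::array<long, N>& shape)
{
    std::array<long, N> strides{};
    long s = 1;
    for(int i = int(N) - 1; i >= 0; --i)
    {
        strides[i] = s;
        s *= shape[i];
    }
    return strides;
}

int main()
{
    // k: (num_blocks, num_heads_k, d / 8, page_block_size, 8) with d = 128, page_block_size = 128
    auto ks = contiguous_strides<5>({64, 8, 16, 128, 8});
    // v: (num_blocks, num_heads_k, d, page_block_size)
    auto vs = contiguous_strides<4>({64, 8, 128, 128});

    std::printf("k: batch_stride_k=stride(0)=%ld nhead_stride_k=stride(1)=%ld stride_k=stride(2)=%ld\n",
                ks[0], ks[1], ks[2]);
    std::printf("v: batch_stride_v=stride(0)=%ld nhead_stride_v=stride(1)=%ld stride_v=stride(2)=%ld\n",
                vs[0], vs[1], vs[2]);
    return 0;
}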
@@ -292,8 +292,8 @@ fmha_fwd_splitkv_args get_ck_fmha_varlen_fwd_splitkv_args(bool has_lse,
 
 std::vector<at::Tensor>
 mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
-               const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
-               const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
+               const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x num_heads_k x head_size / 8 x page_block_size x 8 if there's a block_table.
+               const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i or num_blocks x num_heads_k x head_size x page_block_size if there's a block_table.
               c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
                const at::Tensor &cu_seqlens_q, // b+1
                const at::Tensor &cu_seqlens_k, // b+1
@@ -335,6 +335,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size
         CHECK_DEVICE(block_table);
         TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32");
         TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
+        CHECK_CONTIGUOUS(k);
+        CHECK_CONTIGUOUS(v);
     }
 
     TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
@@ -348,11 +350,11 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size = sizes[2];
-    const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
+    const int num_heads_k = k.size(1);
 
     const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
     const int num_blocks = !paged_KV ? 0 : k.size(0);
-    const int page_block_size = !paged_KV ? 1 : k.size(1);
+    const int page_block_size = !paged_KV ? 1 : k.size(3);
     TORCH_CHECK(!paged_KV || page_block_size % 128 == 0, "Paged KV cache block size must be divisible by 128");
 
     if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
@@ -394,8 +396,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size
         CHECK_SHAPE(k, total_k, num_heads_k, head_size);
         CHECK_SHAPE(v, total_k, num_heads_k, head_size);
     } else {
-        CHECK_SHAPE(k, num_blocks, page_block_size, num_heads_k, head_size);
-        CHECK_SHAPE(v, num_blocks, page_block_size, num_heads_k, head_size);
+        CHECK_SHAPE(k, num_blocks, num_heads_k, head_size / 8, page_block_size, 8);
+        CHECK_SHAPE(v, num_blocks, num_heads_k, head_size, page_block_size);
         CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
     }
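The new CHECK_SHAPE calls pin down the paged-KV layout end to end. A standalone libtorch sketch (not part of this PR; sizes are made-up) that allocates caches the updated checks would accept and reads back the fields the wrapper now derives from them (num_heads_k = k.size(1), page_block_size = k.size(3)):

// Sketch only: vllm-style cache allocation matching the shapes checked above.
#include <torch/torch.h>
#include <cstdio>

int main()
{
    const int64_t num_blocks = 64, num_heads_k = 8, head_size = 128, page_block_size = 128;

    auto opts = torch::TensorOptions().dtype(torch::kFloat16);
    auto k = torch::empty({num_blocks, num_heads_k, head_size / 8, page_block_size, 8}, opts);
    auto v = torch::empty({num_blocks, num_heads_k, head_size, page_block_size}, opts);

    // mirrors the wrapper: num_heads_k from dim 1, page_block_size from dim 3 (must be a multiple of 128)
    std::printf("num_heads_k=%ld page_block_size=%ld ok=%d\n",
                (long)k.size(1), (long)k.size(3), int(k.size(3) % 128 == 0));

    // block_table (batch_size x max_num_blocks_per_seq, int32) holds indices into dim 0 of k/v
    auto block_table = torch::zeros({2, 4}, torch::dtype(torch::kInt32));
    (void)v;
    (void)block_table;
    return 0;
}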

@@ -444,7 +446,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size
     }
 
     int num_splits = 0;
-    num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, 0, num_splits);
+    num_splits = flash::override_num_splits_if_necessary(batch_size, num_heads, max_seqlen_q, head_size, head_size,
+                                                         /*p_drop=*/0, /*is_prefill=*/true, num_splits);
     TORCH_CHECK(num_splits > 0, "num_splits should greater than 0");
     TORCH_CHECK(num_splits <= 128, "num_splits greater than 128 is not supported");
 