[Attention] MLA with chunked prefill #12639
Conversation
Removed the V1 tag because although it does move some code out of the v1 flash attention backend, I didn't want anyone to get the impression that this PR adds support for MLA in V1.
  TORCH_CHECK(src_cache.device() == dst.device(),
              "src_cache and dst must be on the same device");
  TORCH_CHECK(src_cache.device() == block_table.device(),
              "src_cache and block_table must be on the same device");
  TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
              "src_cache and cu_seq_lens must be on the same device");
  if (seq_starts.has_value()) {
    TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
                "src_cache and seq_starts must be on the same device");
  }
Future work: this is generally useful across all kernels and we should probably factor these checks out into a helper function. Something like this:
namespace detail {

// Get device from either Tensor or optional<Tensor>
inline std::optional<torch::Device> get_device(const torch::Tensor& tensor) {
  return tensor.device();
}

inline std::optional<torch::Device> get_device(
    const std::optional<torch::Tensor>& maybe_tensor) {
  return maybe_tensor.has_value() ? std::optional(maybe_tensor.value().device())
                                  : std::nullopt;
}

}  // namespace detail

template <typename First, typename... Rest>
void check_same_device(const First& first, const Rest&... rest) {
  auto first_device = detail::get_device(first);
  if (!first_device.has_value()) return;
  ([&](const auto& tensor) {
    auto device = detail::get_device(tensor);
    if (device.has_value()) {
      TORCH_CHECK(*device == *first_device,
                  "All tensors must be on the same device");
    }
  }(rest), ...);
}
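For illustration, a sketch of how the explicit checks above could then collapse into a single call. The wrapper name gather_cache_checks is hypothetical, and this assumes the helper above plus the usual torch headers and <optional> are in scope:

// Hypothetical call site: seq_starts is an optional tensor, so both overloads
// of detail::get_device are exercised; empty optionals are simply skipped.
void gather_cache_checks(const torch::Tensor& src_cache,
                         const torch::Tensor& dst,
                         const torch::Tensor& block_table,
                         const torch::Tensor& cu_seq_lens,
                         const std::optional<torch::Tensor>& seq_starts) {
  check_same_device(src_cache, dst, block_table, cu_seq_lens, seq_starts);
}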
Agreed, that would be nice; I'll work on a separate PR.
vllm/attention/backends/mla/utils.py
if self.chunked_prefill_enabled:
    if not hasattr(self, "chunked_prefill_workspace"):
        # Note: self.runner.device does not return the correct device
        # for this process (init_device sets the correct device, but
        # only on the Worker). The only way I've figured out to get the
        # correct device is to allocate the workspace on the first call
        # to begin_forward and use the device of the input tokens.
        assert model_input.input_tokens is not None
        self.chunked_prefill_workspace = torch.empty(
            (self.chunked_prefill_workspace_size,
             self.model_config.get_head_size()),
            dtype=self.model_config.dtype,
            device=model_input.input_tokens.device,
        )
Here we are allocating workspace used for the decompressed KV cache. This is going to happen during the profile run -- @WoosukKwon do you see any issues with doing that here?
To me it seems fine for lack of an official way to allocate workspace like this. If we leave it here, I would tighten up the comment a bit (remove the 1st person)
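For a sense of scale, a rough back-of-the-envelope estimate of this workspace buffer; the numbers below are illustrative assumptions, not values taken from this PR:

// Assumed numbers: a bf16 workspace of shape [128 * 1024, 576]
// (tokens x MLA latent head size) works out to ~144 MiB.
#include <cstdint>

constexpr std::int64_t kWorkspaceTokens = 128 * 1024;  // chunked_prefill_workspace_size (assumed)
constexpr std::int64_t kHeadSize = 576;                // model_config.get_head_size() for MLA (assumed)
constexpr std::int64_t kElemBytes = 2;                 // bf16
constexpr std::int64_t kWorkspaceBytes =
    kWorkspaceTokens * kHeadSize * kElemBytes;
static_assert(kWorkspaceBytes == 144LL * 1024 * 1024,
              "~144 MiB for these assumed dimensions");

A buffer of that order is modest, but allocating it during the profile run means it gets counted when the remaining memory is sized for the KV cache.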
Overall looks really good, great work. Concerns are mainly on the memory/workspace/profile run issues which we all know are really hard to get right:
- I still have a question about where the best spot to allocate the workspace is, but I think what's happening here is OK. Unless people spot issues with this that I don't see, I don't think this should be a blocker.
- I wonder if it would be better to detect if we are in the profile run and allocate temporary tensors of size equal to the upper limit on the workspace required, instead of what we are doing now. It sounds like there might be an edge case where we run out of memory, and if so we should address it before landing.
This should be addressed by 1c59597.
NOTE: @pathorn found a bug when stress testing R1; will notify here when resolved. Edit: should be resolved by 920ecc6#diff-00753a3c1f378f8b8c60e9eb10b94c3cbbfcea74fca6e66712e5d4ae360f6741
if attn_metadata.is_profile_run and \
        attn_metadata.chunked_prefill_workspace is not None:
    # During the profile run, try to simulate the worst-case output size
    # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`,
    # since this can be large.
    _ = torch.empty(
        (attn_metadata.chunked_prefill_workspace.shape[0],
         self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
        device=k_c_normed.device,
        dtype=k_c_normed.dtype,
    )
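To see why the worst case matters, a rough estimate of the simulated kv_b_proj output using the same assumed workspace size as above and DeepSeek-style head dimensions; again, these are illustrative assumptions, not values from this PR:

// Assumed numbers: 128 * 1024 workspace rows, 128 heads,
// qk_nope_head_dim = v_head_dim = 128, bf16 -> exactly 8 GiB.
#include <cstdint>

constexpr std::int64_t kWorkspaceRows = 128 * 1024;  // chunked_prefill_workspace.shape[0] (assumed)
constexpr std::int64_t kNumHeads = 128;              // assumed
constexpr std::int64_t kQkNopeHeadDim = 128;         // assumed
constexpr std::int64_t kVHeadDim = 128;              // assumed
constexpr std::int64_t kElemBytes = 2;               // bf16
constexpr std::int64_t kOutputBytes =
    kWorkspaceRows * kNumHeads * (kQkNopeHeadDim + kVHeadDim) * kElemBytes;
static_assert(kOutputBytes == 8LL * 1024 * 1024 * 1024,
              "~8 GiB for these assumed dimensions");

An intermediate of that magnitude dwarfs the workspace buffer itself, so reserving it during the profile run (as the dummy allocation above does) avoids underestimating peak memory.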
Looks great - definitely feel good about the profile_run now
🎉
Need to do more benchmarking to see if this makes sense to be on by default in V0, but it lays the groundwork for a V1 implementation. (#13111 may help performance.)
Shout-out to @pathorn for assisting with hardening this PR
Future work:
- Handling of self.kv_b_proj(kv_c_normed) in the profile run