
[Attention] MLA with chunked prefill #12639

Open · wants to merge 36 commits into main from lwilkinson/chunked-mla

Conversation

@LucasWilkinson (Collaborator) commented Feb 1, 2025

Need to do more benchmarking to see if this makes sense to be on by default in V0, but it lays the groundwork for a V1 implementation. (#13111 may help performance.)

lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,enable_chunked_prefill=False --task gsm8k --num_fewshot=5 --limit 100

vllm (pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,enable_chunked_prefill=False), gen_kwargs: (None), limit: 100.0, num_fewshot: 5, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.66|±  |0.0476|
|     |       |strict-match    |     5|exact_match|↑  | 0.66|±  |0.0476|


lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,enable_chunked_prefill=True --task gsm8k --num_fewshot=5 --limit 100


vllm (pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,enable_chunked_prefill=True), gen_kwargs: (None), limit: 100.0, num_fewshot: 5, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.66|±  |0.0476|
|     |       |strict-match    |     5|exact_match|↑  | 0.66|±  |0.0476|
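For reference, here is a rough Python equivalent of the two configurations above using vLLM's offline LLM API (a sketch only; the prompt is illustrative and the constructor arguments simply mirror the model_args in the commands):

```python
from vllm import LLM, SamplingParams

def build_engine(enable_chunked_prefill: bool) -> LLM:
    # Mirrors the lm_eval model_args above; only the chunked-prefill flag differs.
    return LLM(
        model="deepseek-ai/DeepSeek-V2-Lite-Chat",
        tensor_parallel_size=2,
        dtype="auto",
        gpu_memory_utilization=0.9,
        trust_remote_code=True,
        max_model_len=16384,
        enable_chunked_prefill=enable_chunked_prefill,
    )

llm = build_engine(enable_chunked_prefill=True)
outputs = llm.generate(
    ["Q: Natalia sold 48 clips in April and half as many in May. "
     "How many clips did she sell in total? A:"],
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```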

Shout-out to @pathorn for assisting with hardening this PR

Future work:


github-actions bot commented Feb 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@LucasWilkinson changed the title from "[Attention] WIP MLA with chunked prefill" to "[WIP][Attention] WIP MLA with chunked prefill" on Feb 1, 2025

mergify bot commented Feb 6, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on Feb 6, 2025
@LucasWilkinson force-pushed the lwilkinson/chunked-mla branch 2 times, most recently from 463e453 to c542cc4 on February 6, 2025 05:24
@mergify bot added v1 and removed needs-rebase labels on Feb 6, 2025
@LucasWilkinson changed the title from "[WIP][Attention] WIP MLA with chunked prefill" to "[Attention] WIP MLA with chunked prefill" on Feb 6, 2025
@LucasWilkinson marked this pull request as ready for review on February 6, 2025 05:49

mergify bot commented Feb 7, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on Feb 7, 2025
@mergify bot removed the needs-rebase label on Feb 7, 2025
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
@tlrmchlsmth removed the v1 label on Feb 14, 2025
@mergify bot added the v1 label on Feb 14, 2025
@tlrmchlsmth (Collaborator)

Removed the V1 tag because, although this PR does move some code out of the V1 flash attention backend, I didn't want anyone to get the impression that it adds MLA support to V1.

Comment on lines +694 to +703
TORCH_CHECK(src_cache.device() == dst.device(),
            "src_cache and dst must be on the same device");
TORCH_CHECK(src_cache.device() == block_table.device(),
            "src_cache and block_table must be on the same device");
TORCH_CHECK(src_cache.device() == cu_seq_lens.device(),
            "src_cache and cu_seq_lens must be on the same device");
if (seq_starts.has_value()) {
  TORCH_CHECK(src_cache.device() == seq_starts.value().device(),
              "src_cache and seq_starts must be on the same device");
}
@tlrmchlsmth (Collaborator) commented on Feb 14, 2025

Future work: this is generally useful across all kernels and we should probably factor these checks out into a helper function. Something like this:

namespace detail {
// Get device from either Tensor or optional<Tensor>
inline std::optional<torch::Device> get_device(const torch::Tensor& tensor) {
    return tensor.device();
}

inline std::optional<torch::Device> get_device(const std::optional<torch::Tensor>& maybe_tensor) {
    return maybe_tensor.has_value() ? std::optional(maybe_tensor.value().device()) 
                                  : std::nullopt;
}
} // namespace detail

template <typename First, typename... Rest>
void check_same_device(const First& first, const Rest&... rest) {
    auto first_device = detail::get_device(first);
    if (!first_device.has_value()) return;
    
    ([&](const auto& tensor) {
        auto device = detail::get_device(tensor);
        if (device.has_value()) {
            TORCH_CHECK(*device == *first_device, "All tensors must be on the same device");
        }
    }(rest), ...);
}

@LucasWilkinson (Collaborator, Author)

Agreed, that would be nice. I'll work on a separate PR.

Comment on lines 416 to 429
if self.chunked_prefill_enabled:
    if not hasattr(self, "chunked_prefill_workspace"):
        # Note: self.runner.device does not return the correct device
        # for this process (init_device sets the correct device, but
        # only on the Worker). The only way I've figured out to get the
        # correct device is to allocate the workspace on the first call
        # to begin_forward and use the device of the input tokens.
        assert model_input.input_tokens is not None
        self.chunked_prefill_workspace = torch.empty(
            (self.chunked_prefill_workspace_size,
             self.model_config.get_head_size()),
            dtype=self.model_config.dtype,
            device=model_input.input_tokens.device,
        )
Collaborator

Here we are allocating workspace used for the decompressed KV cache. This is going to happen during the profile run -- @WoosukKwon do you see any issues with doing that here?

To me it seems fine, given the lack of an official way to allocate workspace like this. If we leave it here, I would tighten up the comment a bit (remove the first person).
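For illustration, a minimal sketch of the lazy-allocation pattern under discussion, with the comment tightened as suggested (the helper name and the `state` argument are hypothetical, not from this PR):

```python
import torch

def ensure_chunked_prefill_workspace(state, device: torch.device) -> torch.Tensor:
    """Allocate the decompressed-KV workspace on first use.

    The runner's device attribute is not reliable in this process, so the
    workspace is allocated lazily from the device of the first input batch.
    `state` stands in for the metadata builder and is expected to carry
    `chunked_prefill_workspace_size` and `model_config`.
    """
    if getattr(state, "chunked_prefill_workspace", None) is None:
        state.chunked_prefill_workspace = torch.empty(
            (state.chunked_prefill_workspace_size,
             state.model_config.get_head_size()),
            dtype=state.model_config.dtype,
            device=device,
        )
    return state.chunked_prefill_workspace
```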

Signed-off-by: Lucas Wilkinson <[email protected]>
@tlrmchlsmth (Collaborator) left a comment

Overall this looks really good; great work. My concerns are mainly about the memory/workspace/profile-run issues, which we all know are really hard to get right:

  • I still have a question about where the best spot to allocate the workspace is, but I think what's happening here is OK. Unless people spot issues with this that I don't see, I don't think this should be a blocker.
  • I wonder if it would be better to detect if we are in the profile run and allocate temporary tensors of size equal to the upper limit on the workspace required, instead of what we are doing now. It sounds like there might be an edge case where we run out of memory, and if so we should address it before landing.

Signed-off-by: Lucas Wilkinson <[email protected]>
@LucasWilkinson (Collaborator, Author)

@tlrmchlsmth

I wonder if it would be better to detect if we are in the profile run and allocate temporary tensors of size equal to the upper limit on the workspace required, instead of what we are doing now. It sounds like there might be an edge case where we run out of memory, and if so we should address it before landing.

This should be addressed by 1c59597.

Without this commit I get:

model weights take 84.11GiB; non_torch_memory takes 5.13GiB; PyTorch activation peak memory takes 0.19GiB; the rest of the memory reserved for KV Cache is 36.41GiB.

With it I get:

model weights take 84.11GiB; non_torch_memory takes 5.13GiB; PyTorch activation peak memory takes 1.17GiB; the rest of the memory reserved for KV Cache is 35.42GiB.

In other words, the profile run now accounts for roughly 1 GiB of additional peak activation memory (the worst-case decompressed KV output), and the KV cache reservation shrinks by the same amount rather than risking an out-of-memory error at runtime.

Signed-off-by: Lucas Wilkinson <[email protected]>

mergify bot commented Feb 15, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on Feb 15, 2025
@LucasWilkinson (Collaborator, Author) commented Feb 15, 2025

NOTE: @pathorn found a bug when stress testing R1, will notify here when resolved

https://vllm-dev.slack.com/archives/C08AD2B5HH8/p1739521144253459?thread_ts=1739486497.566799&cid=C08AD2B5HH8

Edit: should be resolved by 920ecc6#diff-00753a3c1f378f8b8c60e9eb10b94c3cbbfcea74fca6e66712e5d4ae360f6741

Comment on lines 1433 to 1443
if attn_metadata.is_profile_run and \
        attn_metadata.chunked_prefill_workspace is not None:
    # During the profile run, try to simulate the worst-case output size
    # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`,
    # since this can be large.
    _ = torch.empty(
        (attn_metadata.chunked_prefill_workspace.shape[0],
         self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
        device=k_c_normed.device,
        dtype=k_c_normed.dtype,
    )
Collaborator

Looks great - definitely feel good about the profile_run now

Signed-off-by: Lucas Wilkinson <[email protected]>
@tlrmchlsmth (Collaborator) left a comment

🎉

@mergify bot removed the needs-rebase label on Feb 18, 2025
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Lucas Wilkinson <[email protected]>
@tlrmchlsmth added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Feb 18, 2025
@tlrmchlsmth enabled auto-merge (squash) on February 18, 2025 13:47