[Feature]: Tree attention about Speculative Decoding #3960

Closed
yukavio opened this issue Apr 10, 2024 · 13 comments

@yukavio

yukavio commented Apr 10, 2024

🚀 The feature, motivation and pitch

I want to implement tree attention for vLLM, as mentioned in the roadmap. However, I don’t know whether I should build it on the paged-attention kernel currently in vLLM or on FlashInfer, because I found that we plan to replace this kernel in this PR.

Alternatives

No response

Additional context

No response

@cadedaniel
Collaborator

cadedaniel commented Apr 10, 2024

Thanks for your interest in contributing! FYI, tree attention is a bit complicated to implement with a non-contiguous KV cache, since intra-block attention masking has not been implemented anywhere AFAIK. We can get around this by limiting vLLM to a block size of 1, but this makes it difficult to optimize verification latency because it restricts the allowed vLLM configuration space.
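A minimal sketch of the block-size-1 workaround, under stated assumptions: the draft tree is given as parent pointers (`parents[i]` is the parent of draft token `i`, or `-1` if it follows the prefix directly), each draft token gets its own single-token physical block starting at a hypothetical `first_free_block`, and each root-to-leaf branch is scored as a separate sequence. Because every block holds exactly one token, a branch's block table can point at exactly its ancestors' blocks, so no masking inside a block is needed. `branch_block_tables` is a hypothetical helper, not a vLLM API.

```python
def branch_block_tables(prefix_blocks, parents, first_free_block):
    """Build one block table per root-to-leaf branch of a draft tree.

    Hypothetical helper assuming block_size == 1:
      prefix_blocks     physical blocks of the prompt, one token each
      parents[i]        parent of draft node i, or -1 if it follows the prefix
      first_free_block  first unused physical block id (assumed contiguous free range)
    """
    n = len(parents)
    # Every draft node gets its own single-token physical block.
    node_block = [first_free_block + i for i in range(n)]
    internal = set(parents)  # nodes with at least one child (plus the -1 sentinel)
    tables = []
    for leaf in range(n):
        if leaf in internal:
            continue  # only leaves end a branch
        path, node = [], leaf
        while node != -1:  # walk leaf -> root collecting ancestor blocks
            path.append(node_block[node])
            node = parents[node]
        # A branch sees the prefix, then its ancestors in root-to-leaf order.
        tables.append(prefix_blocks + path[::-1])
    return tables
```

With `parents = [-1, 0, 0, 1]`, prefix blocks `[7, 8, 2]` and `first_free_block = 20`, this yields `[[7, 8, 2, 20, 22], [7, 8, 2, 20, 21, 23]]`: both branches reuse node 0's block, and neither sees the other branch's tokens. With a larger block size, tokens from different branches could share a block, and that is where intra-block masking becomes necessary.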

The way I'd recommend going about this is to implement intra-block attention masking first, then integrate it with vLLM. This is the surefire way to obtain the best latency reduction possible in vLLM. The steps are as follows:

  • (Kernel) Support attention mask inside single block (M)
  • (Worker) Support attention mask in Worker API (S-M)
  • (Spec Decode Framework) Propose, score, and verify top-k candidates (M) (e.g. implement replacement for this)
  • (Spec Decode Framework) Defragment accepted KV blocks (S-M)
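For the third step, a minimal sketch of greedy verification over a draft tree, assuming one scoring pass has produced the target model's argmax token after the prefix (`root_pred`) and after every draft node (`target_pred[i]`). It walks down from the root, at each level accepting the child whose draft token matches the target's prediction. This only covers greedy decoding; lossless speculative sampling needs a probabilistic acceptance rule instead. The function name and arguments are hypothetical.

```python
def verify_greedy(tokens, parents, root_pred, target_pred):
    """Greedy acceptance over a draft token tree (hypothetical helper).

    tokens[i]       draft token at node i
    parents[i]      parent of node i, or -1 if it follows the prefix directly
    root_pred       target model's argmax token right after the prefix
    target_pred[i]  target model's argmax token right after node i
    Returns (accepted nodes from root to leaf, bonus token from the target).
    """
    children = {}
    for i, p in enumerate(parents):
        children.setdefault(p, []).append(i)
    accepted = []
    current, pred = -1, root_pred
    while True:
        # Accept the first child whose draft token equals the target's prediction.
        match = next((c for c in children.get(current, []) if tokens[c] == pred), None)
        if match is None:
            # No child matches: the target's own prediction is emitted as the bonus token.
            return accepted, pred
        accepted.append(match)
        current, pred = match, target_pred[match]
```

For example, `verify_greedy([5, 7, 9, 2], [-1, 0, 0, 1], 5, [9, 4, 3, 8])` accepts nodes `[0, 2]` and returns bonus token `3`. The accepted node indices are what the fourth step (defragmentation) would then consume.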

After the remaining open sourcing work is complete, I'll add some documentation for this.

More background information here: https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit#heading=h.kk7dq05lc6q8

@reyoung

reyoung commented Apr 10, 2024

@cadedaniel @yukavio

Tree attention mechanisms can also be utilized to generate multiple outcomes from the same prompt by varying the seeds.

This approach is an effective strategy to ensure the stability of results produced by Large Language Models (LLMs). For instance, when employing an LLM as a scoring tool to derive metrics, one could sample the LLM's outputs multiple times. By averaging these samples, a more reliable result can be obtained.


This feature might become available following the implementation of tree attention mechanisms.
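A sketch of how the multi-sample case could fit the same tree structure, assuming the tree is described by parent pointers (`parents[i]` is the parent of node `i`, or `-1` for a node attached directly to the prompt): `num_samples` continuations of one prompt form a tree of independent chains, so each sample's tokens attend to the shared prompt and to their own earlier tokens, but never to another sample. Both helpers are hypothetical.

```python
def fan_out_parents(num_samples, length):
    """Parent pointers for num_samples independent continuations of one prompt.

    Sample s occupies draft nodes [s * length, (s + 1) * length).
    """
    if length < 1:
        raise ValueError("each sample needs at least one token")
    parents = []
    for s in range(num_samples):
        start = s * length
        parents.append(-1)  # first token of every sample hangs off the shared prompt
        parents.extend(range(start, start + length - 1))  # then a plain chain
    return parents


def visible_nodes(parents, i):
    """Draft nodes that node i may attend to (its ancestors and itself), root first."""
    out = []
    while i != -1:
        out.append(i)
        i = parents[i]
    return out[::-1]
```

`fan_out_parents(2, 3)` gives `[-1, 0, 1, -1, 3, 4]`, and `visible_nodes` on that tree for node 5 gives `[3, 4, 5]`: the last token of the second sample sees only its own sample, plus the prompt, which every node sees.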

@yukavio
Author

yukavio commented Apr 10, 2024

@cadedaniel
Thanks for your reply. I have read your document, and it seems the key problem is that each token in the scoring phase requires a loop over, and a computation against, the entire KV cache.
I think this can be solved by storing all pre-score tokens of a given sequence at contiguous addresses, instead of treating them as separate sequences after expansion. That way, we can compute attention efficiently on tensor cores with a specific attention mask.
However, this requires organizing the pre-score tokens as one sequence (left in the image) instead of multiple sequences (right in the image).
[Image: pre-score tokens organized as one sequence (left) vs. multiple sequences (right)]
If you think this way of organizing pre-score tokens is appropriate, I can implement the tensor-core CUDA kernel with a tree attention mask.
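A minimal sketch of the single-sequence layout (left), written in plain Python for readability and assuming: KV positions are laid out as `[prefix | draft tokens in node order]`, the tree is given as parent pointers, and attention is single-head. Each draft token attends to the whole prefix plus its own ancestors and itself, so one masked attention call over the contiguous block scores every branch at once. `tree_mask` and `masked_attention` are illustrative names, not vLLM or FlashInfer functions; in PyTorch the same boolean mask could be passed as `attn_mask` to `torch.nn.functional.scaled_dot_product_attention`.

```python
import math


def tree_mask(parents, prefix_len):
    """Boolean mask: mask[i][j] is True if draft node i may attend to KV position j.

    Assumed KV layout: [prefix tokens | draft nodes in index order].
    parents[i] is node i's parent, or -1 if it follows the prefix directly;
    the parent pointers must form a tree (no cycles).
    """
    n = len(parents)
    mask = []
    for i in range(n):
        row = [True] * prefix_len + [False] * n  # every draft node sees the whole prefix
        node = i
        while node != -1:  # enable itself and each ancestor up to the root
            row[prefix_len + node] = True
            node = parents[node]
        mask.append(row)
    return mask


def masked_attention(q, k, v, mask):
    """Single-head softmax(q k^T / sqrt(d)) v that skips positions where mask is False."""
    d = len(q[0])
    out = []
    for qi, row in zip(q, mask):
        scores = [
            sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) if allowed else -math.inf
            for kj, allowed in zip(k, row)
        ]
        top = max(scores)  # finite, since every row allows at least the node itself
        weights = [math.exp(s - top) for s in scores]  # exp(-inf) == 0.0 for masked slots
        total = sum(weights)
        out.append([
            sum(w * vj[c] for w, vj in zip(weights, v)) / total for c in range(len(v[0]))
        ])
    return out
```

With `parents = [-1, 0, 0, 1]`, row 3 of the packed output should match ordinary attention of query 3 over the prefix plus nodes 0, 1 and 3, which is exactly what the expanded multi-sequence layout (right) would compute for that branch, but without duplicating the prefix reads per branch.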

@cadedaniel
Collaborator

@yukavio you should talk with @LiuXiaoxuanPKU, who is adding MQA scoring to vLLM.

@jon-chuang
Contributor

@cadedaniel what is the status of this?

@jon-chuang
Contributor

I guess it is blocked on this? #5691

@jon-chuang
Contributor

jon-chuang commented Aug 6, 2024

Anyway, for reference, this is what people mean by tree attention, right? https://arxiv.org/pdf/2305.09781

Sounds like a scheduling nightmare.

@cadedaniel
Collaborator

cadedaniel commented Aug 6, 2024

Yep, it is blocked on #5691. https://arxiv.org/pdf/2305.09781 was the first paper discussing this approach to scoring, although Medusa did a better job in popularizing it.

It is not a scheduling nightmare; the scheduler today with lookahead scheduling already supports tree attention. The complexity is mapping the linear lookahead blocks to the tree structure such that 1) the attention operation has correct causal masks and 2) the linear lookahead blocks are defragmented after the accepted tokens are determined.
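One way to picture that mapping, as a sketch under stated assumptions: the draft tree's nodes are written in index order (parents before children) into the linearly allocated lookahead slots right after the prefix, and the model's position id for a node follows its depth in the tree rather than its slot, so siblings share a position, as is common in tree-attention implementations. `map_tree_to_slots` and its arguments are hypothetical; the block table is a plain list of physical block ids.

```python
def map_tree_to_slots(parents, prefix_len, block_table, block_size):
    """Map draft-tree nodes onto linearly allocated lookahead slots (hypothetical helper).

    Node i is written to slot prefix_len + i. Returns one
    (physical_block, offset, position_id) tuple per node.
    """
    depth = []
    for p in parents:  # requires parents[i] < i (e.g. BFS order) so depth[p] exists
        depth.append(0 if p == -1 else depth[p] + 1)
    mapping = []
    for i, d in enumerate(depth):
        slot = prefix_len + i  # linear lookahead slot
        physical_block = block_table[slot // block_size]
        offset = slot % block_size
        position_id = prefix_len + d  # follows tree depth, not the slot index
        mapping.append((physical_block, offset, position_id))
    return mapping
```

`map_tree_to_slots([-1, 0, 0, 1], 6, [3, 8, 5], 4)` returns `[(8, 2, 6), (8, 3, 7), (5, 0, 7), (5, 1, 8)]`. Node 3 shares block 5 with node 2, which is not one of its ancestors, so the kernel has to mask inside that block (condition 1); once the accepted path is known, the surviving entries still have to be compacted (condition 2).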

See https://docs.google.com/document/d/1Z9TvqzzBPnh5WHcRwjvK2UEeFeq5zMZb5mFE8jR0HCs/edit#heading=h.1fjfb0donq5a for more information on lookahead scheduling.

Edit: you can implement tree attention in pure PyTorch without #5691, but it will have additional overhead.

@jon-chuang
Contributor

jon-chuang commented Aug 6, 2024

So it is MQA with a per-sequence mask on the same block (or 2 blocks if crossing blocks) while speculating?

So at the kernel level one should compact the non-accepted tokens away after verification, right? And that would be the thrust of #5691?

@cadedaniel
Collaborator

Yep, although you can do the compaction in PyTorch without a custom kernel.
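A minimal sketch of that compaction step, assuming draft node `i` was written to slot `prefix_len + i` and the verifier returns the accepted node indices in path order. It copies the accepted entries down so the sequence is contiguous again; plain Python lists stand in for the KV tensors, and `compact_accepted` is a hypothetical name.

```python
def compact_accepted(kv_slots, prefix_len, accepted):
    """Move accepted draft entries into contiguous slots right after the prefix.

    kv_slots: per-slot KV entries (any objects); draft node i sits in slot prefix_len + i.
    accepted: accepted node indices in path order (root first).
    Returns the new logical sequence length; slots past it are free to reuse.
    """
    # Read every source before writing, so overlapping moves cannot clobber data.
    moved = [kv_slots[prefix_len + i] for i in accepted]
    for j, entry in enumerate(moved):
        kv_slots[prefix_len + j] = entry
    return prefix_len + len(accepted)
```

In PyTorch the same move could be written with index tensors as `cache[dst] = cache[src]`; indexing with a tensor on the right-hand side produces a copy, so overlapping source and destination ranges are handled correctly, just as reading `moved` first does here.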

@Siegfried-qgf

> Quoting @yukavio's proposal above: store all pre-score tokens of a sequence contiguously and score them on tensor cores with a tree attention mask, organizing them as one sequence rather than multiple sequences.

Nice idea. How is it progressing? Have you run into any difficulties during implementation?

github-actions bot commented Dec 6, 2024

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

@github-actions github-actions bot added the stale label Dec 6, 2024

github-actions bot commented Jan 6, 2025

This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!

@github-actions github-actions bot closed this as not planned Jan 6, 2025

5 participants