[Feature]: Tree attention for Speculative Decoding #3960
Thanks for your interest in contributing! FYI, tree attention is a bit complicated to implement with a non-contiguous KV cache, since intra-block attention masking has not been implemented anywhere AFAIK. We can get around this by limiting vLLM to a block size of 1, but that makes it difficult to optimize verification latency, since it restricts the allowed vLLM configuration space. The way I'd recommend going about this is to implement intra-block attention masking first, then integrate it with vLLM. This is the surefire way to obtain the best latency reduction possible in vLLM. The steps are as follows:
After the remaining open-sourcing work is complete, I'll add some documentation for this. More background information here: https://docs.google.com/document/d/1T-JaS2T1NRfdP51qzqpyakoCXxSXTtORppiwaj5asxA/edit#heading=h.kk7dq05lc6q8
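To make the masking requirement concrete, here is a minimal PyTorch sketch (my own illustration, not vLLM code) of the mask a token tree needs: each speculative token may attend only to itself and its ancestors in the tree, on top of the already-committed context.

```python
import torch

def tree_attention_mask(parents: list[int]) -> torch.Tensor:
    """parents[i] is the parent index of speculative token i, or -1 for a root.

    Returns a boolean [n, n] mask where mask[i, j] is True iff token i may
    attend to token j, i.e. j is i itself or one of i's ancestors.
    """
    n = len(parents)
    mask = torch.zeros(n, n, dtype=torch.bool)
    for i in range(n):
        j = i
        while j != -1:        # walk up to the root, marking ancestors
            mask[i, j] = True
            j = parents[j]
    return mask

# Toy tree: 0 -> 1, 0 -> 2, 2 -> 3 (two candidate branches from the root).
print(tree_attention_mask([-1, 0, 0, 2]))
```

Inside a paged KV cache this mask has to be applied within a block as well as across blocks, which is exactly the intra-block masking mentioned above.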
Tree attention can also be used to generate multiple outputs from the same prompt by varying the seed, which helps stabilize the results produced by Large Language Models (LLMs). For instance, when using an LLM as a scoring tool to derive metrics, one could sample its output multiple times and average the samples to obtain a more reliable result. This use case would become possible once tree attention is implemented.
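As a rough sketch of that use case, assuming vLLM's offline LLM / SamplingParams API (the model name, scoring prompt, and score parsing below are illustrative assumptions, not a recommended setup):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")          # any causal LM works for the sketch
params = SamplingParams(n=8, temperature=0.8, max_tokens=4)

prompt = "Rate the following answer from 1 to 10. Answer: ... Score:"
samples = llm.generate([prompt], params)[0].outputs

scores = []
for sample in samples:
    try:
        scores.append(float(sample.text.strip().split()[0]))
    except (ValueError, IndexError):
        pass                                   # skip samples with no numeric score

if scores:
    print(sum(scores) / len(scores))           # averaged score is more stable
```

This only illustrates the averaging idea itself; how it would map onto tree attention is the open question in this thread.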
@cadedaniel
@yukavio you should talk with @LiuXiaoxuanPKU, who is adding MQA scoring to vLLM
@cadedaniel what is the status of this?
I guess it is blocked on this? #5691
Anyway, for reference, this is what people refer to as tree attention, right? https://arxiv.org/pdf/2305.09781. Sounds like a scheduling nightmare.
Yep, it is blocked on #5691. https://arxiv.org/pdf/2305.09781 was the first paper discussing this approach to scoring, although Medusa did a better job of popularizing it. It is not a scheduling nightmare; the scheduler today with lookahead scheduling already supports tree attention. The complexity is in mapping the linear lookahead blocks to the tree structure such that 1) the attention operation has correct causal masks and 2) the linear lookahead blocks are defragmented after the accepted tokens are determined. See https://docs.google.com/document/d/1Z9TvqzzBPnh5WHcRwjvK2UEeFeq5zMZb5mFE8jR0HCs/edit#heading=h.1fjfb0donq5a for more information on lookahead scheduling. Edit: you can implement tree attention in pure PyTorch without #5691, but it will have additional overhead.
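To make point 2) a bit more concrete, a toy illustration (my own sketch, not the lookahead-scheduling design itself): if the speculative tokens occupy linear lookahead slots and the tree is described by a parent array, then after verification the accepted root-to-leaf path picks out the slot indices that need to be compacted back into a contiguous run.

```python
def accepted_slots(parents: list[int], accepted_leaf: int) -> list[int]:
    """Return the linear slot indices of the accepted root-to-leaf path,
    in root-first order. parents[i] is the parent slot of slot i (-1 = root)."""
    path = []
    i = accepted_leaf
    while i != -1:
        path.append(i)
        i = parents[i]
    return path[::-1]

# Same toy tree as above (0 -> 1, 0 -> 2, 2 -> 3). If the verifier accepts the
# branch ending at slot 3, slots [0, 2, 3] survive and slot 1 is discarded.
print(accepted_slots([-1, 0, 0, 2], 3))  # [0, 2, 3]
```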
So it is MQA with a per-sequence mask on the same block (or two blocks if it crosses a block boundary) while speculating? So at the kernel level one should compact the non-accepted tokens away after verification, right? And that would be the thrust of #5691?
Yep, although you can do the compaction in PyTorch without a custom kernel.
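A minimal illustration of that compaction in plain PyTorch (the tensor layout here is a made-up stand-in for one sequence's speculative region, not vLLM's actual KV-cache layout): gather the entries of the accepted tokens so the region is contiguous again, dropping the rejected branches.

```python
import torch

def compact_accepted(kv: torch.Tensor, accepted_idx: torch.Tensor) -> torch.Tensor:
    """kv: [num_slots, num_heads, head_dim] entries for the speculative region;
    accepted_idx: indices of the accepted tokens, in order."""
    return kv.index_select(0, accepted_idx)

# 6 speculative slots, of which the tokens at slots 0, 2 and 5 were accepted.
kv = torch.randn(6, 8, 64)
compacted = compact_accepted(kv, torch.tensor([0, 2, 5]))
print(compacted.shape)  # torch.Size([3, 8, 64])
```

In vLLM this would presumably need to happen for each layer's cache, which is part of the extra overhead of the pure-PyTorch route mentioned above.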
Nice idea. How is it progressing? Have you run into any difficulties during implementation?
This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!
This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!
🚀 The feature, motivation and pitch
I want to implement the tree attention mentioned in the roadmap for vLLM, but I don't know whether I should build it on the paged-attention kernel implemented in vLLM or on FlashInfer, since I found that we plan to replace this kernel in this PR.
Alternatives
No response
Additional context
No response