Adding support for encoder-decoder models, like T5 or BART #187
Yes, this is in our plan. Adding these models requires modifying vLLM's cache block manager to also manage the encoder's attention cache, which is a notable modification. Feel free to talk to us if you are interested in contributing and accelerating this process.
So... to contribute, we would need to re-implement the model in https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/<MODEL_NAME>.py, except with paged attention (i.e. replacing the stock attention implementation with vLLM's paged attention). It also seems like most linear projections are replaced with either ColumnParallelLinear or RowParallelLinear, right?
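For orientation, below is a minimal plain-PyTorch stand-in for the pattern the existing decoder-only model files follow; the comments mark where vLLM's ColumnParallelLinear/RowParallelLinear and paged attention would slot in. The class and variable names are hypothetical and the exact vLLM APIs vary by version, so treat this as an illustration rather than a template.

```python
# Hypothetical sketch, not vLLM code: in a real vLLM model file the nn.Linear
# layers would be ColumnParallelLinear (fused QKV projection) and
# RowParallelLinear (output projection), and the attention call would go
# through vLLM's paged-attention backend so K/V are read and written via the
# block-based KV cache.
import torch
import torch.nn as nn


class SketchAttention(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # vLLM equivalent: ColumnParallelLinear(hidden_size, 3 * hidden_size, bias=False)
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
        # vLLM equivalent: RowParallelLinear(hidden_size, hidden_size, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [num_tokens, hidden_size]
        q, k, v = self.qkv_proj(hidden_states).chunk(3, dim=-1)
        num_tokens, hidden_size = q.shape
        shape = (num_tokens, self.num_heads, self.head_dim)
        q, k, v = (t.view(shape).transpose(0, 1) for t in (q, k, v))
        # vLLM would call its paged-attention kernel here, passing the KV cache
        # and block tables; this sketch uses stock scaled-dot-product attention.
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        out = out.transpose(0, 1).reshape(num_tokens, hidden_size)
        return self.o_proj(out)


x = torch.randn(10, 256)
print(SketchAttention(hidden_size=256, num_heads=8)(x).shape)  # torch.Size([10, 256])
```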
Are https://github.com/vllm-project/vllm/pull/60/files and https://github.com/vllm-project/vllm/pull/50/files good reference PRs for this?
I see you've already answered this in the FAQ here: https://vllm.readthedocs.io/en/latest/models/adding_model.html
@zhuohan123 Hi, I'm interested in implementing support for encoder-decoder models. Does it require any changes other than what's listed in https://vllm.readthedocs.io/en/latest/models/adding_model.html?
@WoosukKwon @zhuohan123 Hi, my team plans to work on T5 support. We would like to ask a few questions before we start.
Any help is appreciated. Thanks in advance!
We are not actively working on this. Please go ahead!
Yeah, I think the point is to maintain the cross-attention KV cache generated by the encoder. I believe this cache should also be included in our block manager and managed in a blocked fashion, because its size depends on the input size, which can be highly variable.
Some points I can think of:
I believe there may be some other places in our code where we assume the model is decoder-only.
Thanks for taking this on, and please let us know if there are any issues! We are also happy to chat online if you need more detailed suggestions. Feel free to shoot me an email.
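To make the block-manager point above concrete: the cross-attention K/V are written once from the encoder output, and their size scales with the encoder input length, so roughly ceil(encoder_len / block_size) extra blocks are needed per sequence, and no new cross-attention blocks are appended during decoding. A tiny sketch of that bookkeeping follows; the class and field names are hypothetical, not vLLM's actual block manager.

```python
# Hypothetical bookkeeping sketch for encoder cross-attention KV blocks;
# not vLLM's actual block manager.
import math
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class CrossAttentionBlockAccounting:
    block_size: int
    # seq_id -> block ids holding that sequence's cross-attention K/V.
    cross_block_tables: Dict[int, List[int]] = field(default_factory=dict)

    def blocks_needed(self, encoder_len: int) -> int:
        # Fixed at prefill time: the encoder runs once, so unlike decoder
        # self-attention no blocks are appended as decoding proceeds.
        return math.ceil(encoder_len / self.block_size)

    def allocate(self, seq_id: int, encoder_len: int, free_blocks: List[int]) -> None:
        n = self.blocks_needed(encoder_len)
        if n > len(free_blocks):
            raise RuntimeError("not enough free KV-cache blocks for the encoder input")
        self.cross_block_tables[seq_id] = [free_blocks.pop() for _ in range(n)]


# Example: a 4096-token encoder input with 16-token blocks needs 256 extra blocks.
acct = CrossAttentionBlockAccounting(block_size=16)
print(acct.blocks_needed(4096))  # 256
```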
Also, I suppose the encoder cache eviction would be different: the encoder's cross-attention values would need to be kept as long as decoding is active for a prompt, but can be evicted the moment generation is completed.
(Never mind; for the sake of simplicity, LRU should work just fine.)
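A minimal sketch (hypothetical names, continuing the bookkeeping sketch above) of the lifetime just described: cross-attention blocks stay pinned while a request is still decoding and go back to the free pool the moment generation finishes.

```python
from typing import Dict, List


def free_finished(cross_block_tables: Dict[int, List[int]],
                  free_blocks: List[int],
                  finished_seq_ids: List[int]) -> None:
    """Return finished requests' cross-attention blocks to the free pool (sketch)."""
    for seq_id in finished_seq_ids:
        # Blocks belonging to still-active requests are left untouched ("pinned").
        free_blocks.extend(cross_block_tables.pop(seq_id, []))
```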
cc @rib-2
Update: I'm very close to finishing this. I've run T5 with vLLM successfully on my local machine. I think I will be able to submit a PR in the coming weeks.
@js8544 Hello, is there any progress on this now? I would like to use it. Thank you
Would this include BART?
Hello @js8544, thank you so much for this work. My team is very interested in encoder/decoder support, and I would like to offer help with landing this PR. How can I assist? Once the encoder/decoder feature lands, our team plans to integrate Whisper (audio speech recognition) support on top of it, which is what motivates our interest in the encoder/decoder work. @zhuohan123 FYI this relates to
I just submitted a draft PR: #3117. There are still some problems to solve. I would really appreciate any comments or advice.
I tried the pull request; T5 worked but BART did not.
@Elsayed91 did you write your own BART implementation? What was the nature of the issue?
Status update on encoder/decoder models & T5: it has become clear that the aforementioned work rightfully belongs in at least two medium-to-small PRs, rather than a single large PR:
PR 1: vLLM infrastructure to support encoder/decoder, along with unit tests
PR 2: Support for T5
My experience working on T5 integration suggests that T5's relative positional encoding relies on a "custom attention bias" which is (1) not supported by vLLM's flash_attn backend, (2) difficult to integrate efficiently into the existing vLLM workflow, and (3) really an entirely different task from encoder/decoder support. Thus T5 support belongs in its own PR. More on the impact that custom bias has when working with models like T5 can be found in the comments on this post: https://twitter.com/birchlabs/status/1782791645961859142?s=46
Note that Whisper support (#180) takes a dependency on encoder/decoder as well, and will also be in a separate PR.
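For readers unfamiliar with the issue, the sketch below shows in plain PyTorch where T5-style relative positional encoding enters attention: as an additive [num_heads, q_len, k_len] bias on the score matrix before the softmax. Fused paged/flash attention kernels never materialize that score matrix, which is why an arbitrary per-element bias is hard to support. This is only a reference for the math, not vLLM code.

```python
# Reference math only (not a vLLM kernel): attention with a custom additive bias.
import math
import torch


def attention_with_additive_bias(q: torch.Tensor,    # [heads, q_len, head_dim]
                                 k: torch.Tensor,    # [heads, k_len, head_dim]
                                 v: torch.Tensor,    # [heads, k_len, head_dim]
                                 bias: torch.Tensor  # [heads, q_len, k_len]
                                 ) -> torch.Tensor:
    scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
    # T5's relative positional encoding enters here as a learned, bucketed bias
    # over (query position, key position) pairs; this per-element add is what
    # fused paged/flash attention kernels have no hook for. (T5 itself folds the
    # 1/sqrt(d) scale into its weights, but that detail doesn't change the point.)
    scores = scores + bias
    return torch.matmul(torch.softmax(scores, dim=-1), v)


heads, q_len, k_len, d = 8, 4, 6, 64
out = attention_with_additive_bias(torch.randn(heads, q_len, d),
                                   torch.randn(heads, k_len, d),
                                   torch.randn(heads, k_len, d),
                                   torch.randn(heads, q_len, k_len))
print(out.shape)  # torch.Size([8, 4, 64])
```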
I totally agree. The relative attention bias of T5 was very painful to implement, and it is not necessary for other enc-dec models like Whisper. I can add T5 support after your enc-dec infrastructure PR is merged.
By the way, BART would be simpler than T5 because it uses the original Transformer structure. Maybe we can do BART first.
Update:
Is there any documentation for running inference with BART-type models? Thanks.
Hello @anonymousz97, this PR will include BART support and example code for invoking BART. The PR is WIP but should be ready for review soon.
Hi @afeldman-nm! Is this PR also going to support https://huggingface.co/facebook/bart-large-mnli? Thank you
Thanks, I will try it @afeldman-nm
Update: #4888 has landed, enabling the xFormers backend to support encoder attention, decoder self-attention, and decoder cross-attention. #4837 and #4888 (both of which have landed) were prerequisites for #4942.
Does it support the MBartForConditionalGeneration model, @afeldman-nm? Thanks
@afeldman-nm Will that PR include T5 support?
FYI encoder/decoder support has landed (#4942); an example is included in the repository. Currently, vLLM's encoder/decoder support is constrained in which features it is compatible with (e.g., CUDAGraph and pipeline parallelism are not yet supported), so it is now a goal to make more features compatible with vLLM's encoder/decoder processing pipeline.
To that end, RFC #7366 surveys the vLLM features which are currently not compatible with encoder/decoder, with an eye toward bringing vLLM's encoder/decoder support to parity with its decoder-only support. Additionally, #7366 proposes adding custom attention bias support as well as the Whisper and T5 models. The RFC feedback period is one week (until August 16th).
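For anyone landing on this thread now, here is a minimal usage sketch under the post-#4942 state of things. The checkpoint choice (facebook/bart-large-cnn), the sampling settings, and the plain-string prompt are illustrative assumptions; the example script shipped in the vLLM repository is the authoritative reference.

```python
# Sketch: assumes a vLLM build that includes the encoder/decoder support from
# #4942. Model choice and parameters are illustrative, not the official example.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/bart-large-cnn")  # an encoder/decoder (BART) checkpoint
params = SamplingParams(temperature=0.0, max_tokens=64)

prompts = ["vLLM now supports encoder/decoder models such as BART."]
for output in llm.generate(prompts, params):
    # For enc/dec models the prompt goes to the encoder; the decoder produces
    # the generated text.
    print(output.outputs[0].text)
```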
Closing this issue in favor of #7366 |
@DarkLight1337 @afeldman-nm How much time might it take to get support for T5-based models in vLLM? Your response would be appreciated.
Any updates on this?
Will support be added for encoder-decoder models, like T5 or BART? All of the currently supported models are decoder-only.