Adding support for encoder-decoder models, like T5 or BART #187

Closed
shermansiu opened this issue Jun 21, 2023 · 33 comments
Labels
new model Requests to new models

Comments

@shermansiu

Will there be added support for encoder-decoder models, like T5 or BART? All of the currently supported models are decoder-only.

@zhuohan123
Member

Yes, this is in our plan. Adding these models requires modifying vLLM's cache block manager to also manage the attention cache of the encoder, which is a notable modification. Feel free to talk to us if you are interested in contributing and accelerating this process.

@zhuohan123 zhuohan123 added the new model Requests to new models label Jun 21, 2023
@shermansiu
Author

So... to contribute, we would need to re-implement the model in https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/<MODEL_NAME>.py, except with paged attention (i.e. replace self.attn with a paged version and use a KVCache during computation)?

It also seems like most linear projections are replaced with either ColumnParallelLinear or RowParallelLinear, right? So nn.Linear(small, big) is replaced with ColumnParallelLinear(small, big) (thus parallelizing the large number of columns) and nn.Linear(big, small) is replaced by RowParallelLinear(big, small)?
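For concreteness, here is a minimal sketch of that substitution pattern. The class names match vLLM's tensor-parallel layers, but the import path and constructor keyword arguments vary across vLLM versions, so treat the vLLM-specific parts (shown as comments) as assumptions to verify against the target version:

```python
import torch.nn as nn

# Hypothetical Hugging Face-style MLP with the small->big / big->small pattern.
class MLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size)    # small -> big
        self.down_proj = nn.Linear(intermediate_size, hidden_size)  # big -> small

# Rough vLLM-style equivalent (import path and kwargs are assumptions that
# depend on the vLLM version in use):
#   self.up_proj = ColumnParallelLinear(hidden_size, intermediate_size,
#                                       gather_output=False)
#   self.down_proj = RowParallelLinear(intermediate_size, hidden_size,
#                                      input_is_parallel=True)
```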

@shermansiu
Author

I see you've already answered this in the FAQ here: https://vllm.readthedocs.io/en/latest/models/adding_model.html

@js8544
Contributor

js8544 commented Oct 17, 2023

@zhuohan123 Hi, I'm interested in implementing support for encoder-decoder models. Does it require any changes other than what's listed in https://vllm.readthedocs.io/en/latest/models/adding_model.html?

@js8544
Contributor

js8544 commented Nov 24, 2023

@WoosukKwon @zhuohan123 Hi, my team plans to work on T5 support. We would like to ask a few questions before we start.

  1. Is the vLLM team currently working on this, or planning to? If so, there's no point in us duplicating the effort.
  2. @zhuohan123 said above that it requires the cache block manager to also manage the attention cache of the encoder. However, AFAIU the encoder doesn't need KV caches; instead, the manager should handle the decoder's cross-attention KV cache, right?
  3. Apart from managing the cross-attention KV cache in block_manager.py and implementing the model in t5.py, are there any other components that need to change? Could you briefly describe how to implement this with minimal changes?

Any help is appreciated. Thanks in advance!

@zhuohan123
Member

@WoosukKwon @zhuohan123 Hi, my team plans to work on T5 support. We would like to ask a few questions before we start.

  1. Is the vLLM team currently working on this, or planning to? If so, there's no point in us duplicating the effort.

We are not actively working on this. Please go ahead!

  2. @zhuohan123 said above that it requires the cache block manager to also manage the attention cache of the encoder. However, AFAIU the encoder doesn't need KV caches; instead, the manager should handle the decoder's cross-attention KV cache, right?

Yeah, I think the point is to maintain the cross-attention KV cache generated from the encoder output. I believe this cache should also be included in our block manager and managed in a blocked fashion, because its size depends on the input size, which can be highly variable.
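For illustration, a minimal plain-PyTorch sketch (not vLLM code) of why the cross-attention K/V is a natural target for block-managed caching: it is computed once from the encoder output, reused at every decode step, and its size scales with the (variable) encoder input length:

```python
import torch
import torch.nn as nn

d_model, src_len = 512, 37                      # src_len varies per request
k_proj = nn.Linear(d_model, d_model)
v_proj = nn.Linear(d_model, d_model)

encoder_out = torch.randn(1, src_len, d_model)  # [batch, src_len, d_model]
k_cross = k_proj(encoder_out)                   # computed once per request...
v_cross = v_proj(encoder_out)                   # ...then cached for all decode steps

# Every decode step reuses the same cached k_cross / v_cross:
q = torch.randn(1, 1, d_model)                  # current decoder token
scores = q @ k_cross.transpose(-1, -2) / d_model ** 0.5
out = torch.softmax(scores, dim=-1) @ v_cross   # [1, 1, d_model]
```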

  3. Apart from managing the cross-attention KV cache in block_manager.py and implementing the model in t5.py, are there any other components that need to change? Could you briefly describe how to implement this with minimal changes?

Some points I can think of:

  • You might need to change the memory profiling logic in profile_num_available_blocks(). This function profiles the maximum memory usage of the model, and it may need to be changed for the encoder-decoder structure.
  • The blocks for encoders and the blocks for decoders may need to be stored separately, since the encoder's cross-attention cache is shared, and the decoder's cache is per layer.
  • For the model in t5.py, you might need to look at the input and check whether the current call is a prompt run or a generation run. If it's a prompt run, you call the encoder, feed <sos> to the decoder, and run the first decoder step. If it's a generation run, you only call the decoder (see the sketch below).

I believe there may be other places in our code where we assume the model is decoder-only.
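For illustration, hedged pseudocode of the prompt-run vs. generation-run branching described in the last bullet above. All names here (EncoderDecoderModel, is_prompt, start_token_id, ...) are hypothetical and not vLLM's actual interfaces:

```python
# Hedged pseudocode only; not vLLM's actual interfaces.
class EncoderDecoderModel:
    def __init__(self, encoder, decoder, start_token_id: int):
        self.encoder = encoder
        self.decoder = decoder
        self.start_token_id = start_token_id  # e.g. <sos> / decoder_start_token_id

    def forward(self, input_ids, kv_caches, is_prompt: bool):
        if is_prompt:
            # Prompt run: encode the full input, then run the first decoder
            # step on the start token, populating the self-/cross-attention caches.
            encoder_out = self.encoder(input_ids)
            return self.decoder([self.start_token_id], encoder_out, kv_caches)
        # Generation run: only the decoder runs, reading the cached
        # cross-attention state from the block-managed KV cache.
        return self.decoder(input_ids, None, kv_caches)
```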

Any help is appreciated. Thanks in advance!

Thanks for taking this and please let us know if there's any issue! We are also happy to chat online if you need more detailed suggestions. Feel free to shoot me an email at zhuohan[at]berkeley.edu.

@shermansiu
Author

Also, I suppose the encoder cache eviction would be different.

i.e., the encoder's cross-attention values would need to be kept as long as decoding is active for a prompt, but they can be evicted the moment generation completes.

@shermansiu
Author

(Never mind, for the sake of simplicity, LRU should work just fine)

@simon-mo
Collaborator

cc @rib-2

@js8544
Contributor

js8544 commented Jan 31, 2024

Update: I'm very close to finishing this. I've run T5 with vLLM successfully on my local machine. I think I will be able to submit a PR in the coming weeks.

@junior-zsy

@js8544 Hello, is there any progress on this now? I would like to use it. Thank you

@Elsayed91

would this include BART?

@afeldman-nm
Contributor

Hello @js8544 thank you so much for this work. My team is very interested in encoder/decoder.

I would like to offer to help with landing this PR. How can I assist?

Once the encoder/decoder feature is landed, our team plans to integrate Whisper (automatic speech recognition) support on top of it, which motivates our interest in the encoder/decoder work. @zhuohan123 FYI this relates to

#180

@js8544
Contributor

js8544 commented Feb 29, 2024

I just submitted a draft PR: #3117. There are still some problems to solve. I would really appreciate any comments or advice.

@Elsayed91

I tried the pull request; T5 worked but BART did not.

@afeldman-nm
Contributor

afeldman-nm commented Apr 30, 2024

@Elsayed91 did you write your own BART implementation? What was the nature of the issue?

Status update on encoder/decoder models & T5:

It has become clear that the aforementioned work rightfully belongs in at least two medium-small sized PRs, rather than a single large PR:

PR 1: vLLM infrastructure to support encoder/decoder, along with unit tests

PR 2: Support for T5

  • Draft PR: TBD

My experience working on T5 integration suggests that T5's relative positional encoding relies on a "custom attention bias", which is (1) not supported by vLLM's flash_attn backend, (2) difficult to integrate efficiently into the existing vLLM workflow, and (3) really an entirely different task from encoder/decoder support. Thus T5 support belongs in its own PR.

More on the impact that custom attention bias has on working with models like T5 can be found in the comments on this post: https://twitter.com/birchlabs/status/1782791645961859142?s=46
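To make the term concrete, here is a generic attention-with-additive-bias sketch (plain PyTorch). The bias tensor stands in for T5's bucketed relative-position embedding; shapes are illustrative, and T5's exact formulation differs in details (e.g. it drops the 1/sqrt(d) scaling). The point is that paged/flash attention kernels computing softmax(QK^T/sqrt(d))V have no slot for this extra per-(query, key) term:

```python
import torch

heads, q_len, k_len, head_dim = 8, 16, 16, 64
q = torch.randn(heads, q_len, head_dim)
k = torch.randn(heads, k_len, head_dim)
v = torch.randn(heads, k_len, head_dim)

# Learned, position-dependent additive bias ("custom attention bias");
# in T5 this comes from a bucketed relative-position embedding.
rel_bias = torch.randn(heads, q_len, k_len)

scores = q @ k.transpose(-1, -2) / head_dim ** 0.5 + rel_bias
out = torch.softmax(scores, dim=-1) @ v        # [heads, q_len, head_dim]
```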

Note that Whisper support (#180) takes a dependency on encoder/decoder as well, and will also be in a separate PR.

@js8544
Contributor

js8544 commented Apr 30, 2024

(quoting @afeldman-nm's status update above)

I totally agree. The relative attention bias of T5 was very painful to implement, and it is not needed for other enc-dec models like Whisper. I can add T5 support after your enc-dec infra PR is merged.

@js8544
Contributor

js8544 commented Apr 30, 2024

BTW, BART would be simpler than T5 because it uses the original Transformer structure. Maybe we can do BART first.

@afeldman-nm
Contributor

Quick update: the PR to support cross-attention caching (#4837) has been landed. Now I am working on landing the PR to correctly invoke the attention kernel for cross-attention (#4888).

@afeldman-nm
Contributor

afeldman-nm commented Jun 21, 2024

Update:

@anonymousz97

Is there any documentation for running inference with BART-type models? Thanks.

@afeldman-nm
Contributor

afeldman-nm commented Jul 3, 2024

Hello @anonymousz97 , this PR

#4942

will include BART support & example code for invoking BART. This PR is WIP but should be ready for review soon.

@Sapessii

Sapessii commented Jul 5, 2024

Hi @afeldman-nm! Is this PR also going to support https://huggingface.co/facebook/bart-large-mnli?

thank you

@anonymousz97

Thanks, I will try it @afeldman-nm

@afeldman-nm
Contributor

afeldman-nm commented Jul 8, 2024

Update: #4888 has landed, enabling the xFormers backend to support encoder attention, decoder self-attention, and decoder cross-attention. #4837 and #4888 (both of which have landed) were prerequisites for #4942, which completes end-to-end support for encoder/decoder models and also introduces the BART model into vLLM. #4942 is still WIP.

@anonymousz97

Does it support the MBartForConditionalGeneration model, @afeldman-nm? Thanks

@thanhlt998

Hello @anonymousz97 , this PR

#4942

will include BART support & example code for invoking BART. This PR is WIP but should be ready for review soon.

@afeldman-nm Will that PR include T5 support?

@afeldman-nm
Contributor

FYI encoder/decoder support has landed (#4942); there is an example in examples/offline_inference_encoder_decoder.py. BART has been integrated into vLLM (T5 and Whisper have not, to answer a previous question).

Currently, vLLM's encoder/decoder support is constrained in which features it is compatible with (e.g. no CUDAGraph, no pipeline parallelism, ...), so it is now a goal to make more features compatible with vLLM's encoder/decoder processing pipeline.

To that end, RFC #7366 gives an overview of the vLLM features which are currently not compatible with encoder/decoder, with an eye toward bringing vLLM's encoder/decoder support to parity with its decoder-only support.

Additionally, #7366 proposes adding custom attention bias support as well as the Whisper and T5 models.

The RFC feedback period is 1 week (until August 16th).
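For reference, a minimal offline-inference sketch in the spirit of the examples/offline_inference_encoder_decoder.py example mentioned above; the model choice, prompt, and sampling settings here are illustrative, and the example script in the repo is the authoritative reference:

```python
from vllm import LLM, SamplingParams

# BART is the first encoder/decoder model integrated into vLLM (per #4942).
llm = LLM(model="facebook/bart-large-cnn", dtype="float16")
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

# A plain string prompt is fed to the encoder; the decoder starts from the
# model's decoder start token.
outputs = llm.generate(["vLLM is a high-throughput, memory-efficient "
                        "inference and serving engine for LLMs."],
                       sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```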

@DarkLight1337
Member

Closing this issue in favor of #7366

@yugaljain1999

yugaljain1999 commented Sep 12, 2024

@DarkLight1337 @afeldman-nm How long might it take to add support for T5-based models in vLLM?

Your response would be appreciated.
Thanks

@saisurbehera

Any updates on this?

@DarkLight1337
Member

DarkLight1337 commented Jan 14, 2025

For the T5 model, see #11334 + #11901, or #11470.
