Transformer - High VRAM, context length #39
@MarcusLoppe yes indeed, you are correct on both counts. Local attention is tricky to handle with a KV cache. Grouped-query attention is also already available in x-transformers, which this lib is using. OK, no more AI stuff until after the new year 😆
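To illustrate the KV-cache point, here is a minimal NumPy sketch of one simple scheme: a per-token rolling cache that keeps only the last `window` keys and values during causal decoding. The class `SlidingWindowKVCache` is hypothetical and is not how x-transformers or the local-attention library implement caching; block-aligned local attention (fixed chunks rather than a per-token slide) may be part of what makes caching harder in practice.

```python
import numpy as np

class SlidingWindowKVCache:
    """Hypothetical per-token rolling KV cache for causal sliding-window decoding."""

    def __init__(self, window: int, dim: int):
        self.window = window
        self.keys = np.empty((0, dim))
        self.values = np.empty((0, dim))

    def append(self, k: np.ndarray, v: np.ndarray) -> None:
        # Add the newest step, then keep only the last `window` entries.
        self.keys = np.concatenate([self.keys, k[None]], axis=0)[-self.window:]
        self.values = np.concatenate([self.values, v[None]], axis=0)[-self.window:]

    def attend(self, q: np.ndarray) -> np.ndarray:
        # Softmax attention of a single query over the cached window.
        scores = self.keys @ q / np.sqrt(q.shape[-1])
        weights = np.exp(scores - scores.max())
        weights /= weights.sum()
        return weights @ self.values

rng = np.random.default_rng(0)
cache = SlidingWindowKVCache(window=4, dim=8)
for step in range(10):
    k, v, q = rng.standard_normal((3, 8))
    cache.append(k, v)       # cache the current token first,
    out = cache.attend(q)    # so the query can attend to itself
print(cache.keys.shape)  # (4, 8)
```

Memory for the cache stays fixed at `window` entries no matter how long the generated sequence gets.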
Have you given any thought to implementing Flash Attention 2? 😄 It seems like it would help a lot with transformer training and inference speed.
@MarcusLoppe Flash Attention 2 will make it into the next release of PyTorch, so no need!
Will it need any code changes?
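For reference, PyTorch exposes fused attention through `torch.nn.functional.scaled_dot_product_attention`, which selects a backend at runtime (including a FlashAttention kernel on supported CUDA GPUs with fp16/bf16 inputs). The sketch below, using toy shapes, checks that it matches naive causal attention. In general, code that already routes attention through this function picks up newer kernels after a PyTorch upgrade without changes; whether x-transformers does so depends on its configuration (its README documents a flash option, `attn_flash`, which is worth confirming there).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Toy shapes: (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(1, 4, 16, 32) for _ in range(3))

# Fused attention: PyTorch picks a backend (flash, memory-efficient, or math)
# based on device, dtype and input shapes.
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Reference: naive causal attention that materializes the full score matrix.
scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
causal = torch.ones(16, 16, dtype=torch.bool).tril()
naive = scores.masked_fill(~causal, float("-inf")).softmax(dim=-1) @ v

max_diff = (fused - naive).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```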
Hello again, this issue is for next year 😃
When training the transformer, I used the following config:
This resulted in a transformer with 22M parameters.
I then tried to train it on a mesh with 6206 faces, which is 37236 tokens (6206 * 6).
When I fed it the face codes (shape (1, 6206, 128)), it used about 11 GB of VRAM, and by the end of the forward pass it used about 20 GB.
With a 188M-parameter transformer (dim 256), it used 50 GB of VRAM.
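As a rough back-of-the-envelope check, the sketch below computes the size of a single full attention-score matrix at this sequence length, assuming fp32 scores (4 bytes each), one head, one layer, and an implementation that materializes the whole `seq_len x seq_len` matrix. The head and layer counts from the config aren't shown, so real usage would scale with those; fused kernels such as FlashAttention avoid materializing this matrix at all. The helper name is illustrative, and this is not meant to account exactly for the measured numbers above.

```python
def attention_scores_bytes(seq_len: int, heads: int = 1, bytes_per_elem: int = 4) -> int:
    """Bytes for a full (heads, seq_len, seq_len) score tensor in one layer."""
    return heads * seq_len * seq_len * bytes_per_elem

faces = 6206
tokens_per_face = 6
seq_len = faces * tokens_per_face
print(seq_len)  # 37236

per_head = attention_scores_bytes(seq_len)  # fp32, one head, one layer
print(f"{per_head / 2**30:.2f} GiB")  # 5.17 GiB
```

Because the cost is quadratic in `seq_len`, doubling the face count roughly quadruples this term.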
My suggestion is to implement sliding-window attention / local attention, since most long-context LLMs use it and it seems to work.
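A minimal NumPy sketch of causal sliding-window attention, assuming a single head and a window of `window` tokens (each query sees itself and the previous `window - 1` tokens). For clarity it builds a dense mask, so it shows the semantics but does not save memory; memory-efficient implementations compute scores only within blocks, keeping roughly `seq_len * window` scores instead of `seq_len ** 2`. The function names are hypothetical.

```python
import numpy as np

def sliding_window_causal_mask(n: int, window: int) -> np.ndarray:
    """True where query i may attend to key j, i.e. i - window < j <= i."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    return (j <= i) & (j > i - window)

def local_attention(q, k, v, window):
    """Single-head causal attention restricted to a sliding window.

    q, k, v: arrays of shape (n, d). Uses a dense mask for readability only.
    """
    n, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    scores = np.where(sliding_window_causal_mask(n, window), scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)  # diagonal is always unmasked
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 10, 4))
out = local_attention(q, k, v, window=3)
print(sliding_window_causal_mask(5, 3).astype(int))
```

With `window=1` each token attends only to itself, so the output equals `v`; with `window >= seq_len` it reduces to ordinary causal attention.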
Alternatively, create an embedding of the previous tokens and concatenate it with the text-conditioning embedding, so that cross-attention is also aware of the previous tokens.
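A minimal sketch of the shape handling this proposal implies, assuming the previous mesh tokens are embedded with a separate embedding table and concatenated with the text-conditioning embeddings along the sequence axis to form the cross-attention context. `build_cross_attn_context`, `token_table`, and all sizes are hypothetical, not APIs of this library. Since the concatenated context grows with the number of previous tokens, in practice it would likely need truncating or compressing to help with memory.

```python
import numpy as np

def build_cross_attn_context(text_embeds, token_embeds, text_mask, token_mask):
    """Concatenate text and previous-token embeddings along the sequence axis.

    text_embeds: (batch, text_len, dim), token_embeds: (batch, prev_len, dim);
    masks are boolean with True marking valid positions.
    """
    context = np.concatenate([text_embeds, token_embeds], axis=1)
    context_mask = np.concatenate([text_mask, token_mask], axis=1)
    return context, context_mask

rng = np.random.default_rng(0)
batch, text_len, prev_len, dim, num_codes = 2, 8, 32, 16, 128

# Hypothetical embedding table for the mesh's discrete codes.
token_table = rng.standard_normal((num_codes, dim))
prev_ids = rng.integers(0, num_codes, size=(batch, prev_len))
token_embeds = token_table[prev_ids]  # (batch, prev_len, dim)

text_embeds = rng.standard_normal((batch, text_len, dim))
text_mask = np.ones((batch, text_len), dtype=bool)
token_mask = np.ones((batch, prev_len), dtype=bool)

context, context_mask = build_cross_attn_context(
    text_embeds, token_embeds, text_mask, token_mask
)
print(context.shape, context_mask.shape)  # (2, 40, 16) (2, 40)
```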
Also, take a look at whether grouped-query attention would be beneficial :)
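A minimal NumPy sketch of grouped-query attention, assuming causal self-attention where `q_heads` query heads share `kv_heads` key/value heads (each consecutive group of `q_heads // kv_heads` query heads uses one K/V head). Because only `kv_heads` K/V heads are stored, a KV cache shrinks by a factor of `q_heads / kv_heads`; GQA does not by itself shrink the `seq_len x seq_len` score matrix. The function name is illustrative; per the reply above, x-transformers already provides this.

```python
import numpy as np

def grouped_query_attention(q, k, v):
    """Causal GQA for one sequence.

    q: (q_heads, n, d); k, v: (kv_heads, n, d) with q_heads divisible by kv_heads.
    """
    q_heads, n, d = q.shape
    kv_heads = k.shape[0]
    if q_heads % kv_heads != 0:
        raise ValueError("q_heads must be a multiple of kv_heads")
    group = q_heads // kv_heads
    # Repeat each K/V head `group` times: query heads 0..group-1 share K/V head 0, etc.
    k = np.repeat(k, group, axis=0)
    v = np.repeat(v, group, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)  # (q_heads, n, n)
    causal = np.tril(np.ones((n, n), dtype=bool))
    scores = np.where(causal, scores, -np.inf)
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v  # (q_heads, n, d)

rng = np.random.default_rng(0)
q = rng.standard_normal((8, 6, 4))  # 8 query heads
k = rng.standard_normal((2, 6, 4))  # 2 shared key heads
v = rng.standard_normal((2, 6, 4))  # 2 shared value heads
out = grouped_query_attention(q, k, v)
print(out.shape)  # (8, 6, 4)
```

Here the cache would hold 2 K/V heads instead of 8, a 4x reduction.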