Skip to content

Commit

Permalink
update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
jzhang38 committed Feb 4, 2024
1 parent d333c96 commit 2f5346a
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,29 @@ It can be observed that the model retrieves nearly perfectly on 16384. We can fu

<img src="data/heatmap_32768.png" width="800">

It is interesting to see how Mamba starts to forget the beginning of the context when we increase the context length, which is very different from the Transformer that [lost in the middle](https://arxiv.org/abs/2307.03172).

## Scaling the context length to infinity with Transformer-XL Style Training

The maximum training length we can attain is still bounded by the limited GPU memory (in my case 16384 on 8 A100 80G). To overcome this, we can train Mamba in a Transformer-XL style. That is, for each batch, instead of initializing the SSM hidden states with zeros, we initialize them with the hidden states from the previous batch.

<img src="data/Transformer-XL.png" width="800">

What is more, the RNN-like architecture of Mamba gives use another unique advantange: for Transformer-XL, its extended context length is still bounded by the maximum KV cache it can store, while for Mamba, its context cache (the hidden states) does not grow with the context length! This means theorectically we can scale the context length to infinity!

Suppose we set the sequence length within a single batch as 2048. One caveat is that we should no longer randomly shuffle the 2048-chunk in our data loading script. We should load the data in a sequential manner such that for documents longer than 2048, the first 2048 tokens are in the first batch, the next 2048 tokens are in the second batch, and so on. We can still shuffle the documents, but we should not shuffle the tokens within a document.

This idea really excites me. Unfortunately, the current Mamba implementation does not support this.

<img src="data/Mamba_issue.png" width="800">

Instead of twicking the CUDA code myself. I decide to wait for CUDA Master Tri Dao to implement this feature, because I am not confident if I can do it correctly.

What I can do, however, is to modify the torch code. Code written in torch is obviously not as efficient as the CUDA code, and may use significantly more memory. But at least I can try out the smallest Mamba model and use it as a proof of concept.


## Next Step


## References
This repository borrows code from the [yarn repo](https://github.com/jquesnelle/yarn).
This repository borrows code from the [yarn repo](https://github.com/jquesnelle/yarn) and the [Mamba repo](https://github.com/state-spaces/mamba).
Binary file added data/Mamba_issue.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/Transformer-XL.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 2f5346a

Please sign in to comment.