# LongMamba
This repo contains my exploration of Mamba's context scaling. It includes code to:
1. train Mamba on longer context,
2. evaluate Mamba's PPL on the Proof Pile test set, and
3. run a needle-in-a-haystack test (pass-key retrieval).

## Install
<details>
<summary>Code</summary>

```bash
conda create -n longmamba python=3.10 -y
conda activate longmamba
pip3 install torch --index-url https://download.pytorch.org/whl/cu118
pip install "causal-conv1d>=1.1.0"
pip install mamba-ssm
pip install -r requirements.txt
```
</details>
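To sanity-check the install, a quick forward pass along the lines of what `eval.py` does should work. This is a minimal sketch run from the repo root (so the `modeling` package resolves); the 130M checkpoint and the prompt are arbitrary choices, and the model API follows how `eval.py` calls it:

```python
import torch
from transformers import AutoTokenizer
from modeling.mamba_lm import MambaLMHeadModel  # this repo's wrapper, same import as eval.py

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-130m", dtype=torch.bfloat16).to("cuda")
model.eval()

input_ids = torch.tensor([tokenizer("Mamba is a state space model")["input_ids"]]).to("cuda")
with torch.no_grad():
    logits = model(input_ids).logits  # (1, seq_len, vocab_size)
print(logits.shape)
```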

## Mamba Cannot Directly Handle Longer Context
We first run Mamba on the Proof Pile test set and record the average PPL at each context length. The PPL explodes once the context length grows beyond the 2048-token training length. Each point is the PPL over 20 Proof Pile samples truncated to that context length.
<details>
<summary>Code</summary>

```bash
python eval.py \
--tokenized PY007/tokenized_proof_pile_test_neox \
--dataset-min-tokens 32768 \
--samples 20 \
--output-file data/original_mamba.csv \
--min-tokens 2048 \
--max-tokens 12288 \
--tokens-step 2048 \
--truncate \
-m state-spaces/mamba-2.8b-slimpj \
-m state-spaces/mamba-2.8b \
-m state-spaces/mamba-1.4b \
-m state-spaces/mamba-790m \
-m state-spaces/mamba-370m \
-m state-spaces/mamba-130m
python plot.py --xmax 12288 --ymax 20 data/original_mamba.csv
```
</details>
<img src="data/original_mamba.csv.png" alt="PPL explodes when increasing the context length" width="500"/>

To further validate this phenomenon, let's look at the plot below from [Mamba's ICLR rebuttal](https://openreview.net/forum?id=AL1fq05o7H) (unfortunately, the paper was not accepted). It was generated by taking the validation set of the Pile dataset, feeding in each example with no padding or concatenation, and measuring the loss per token.
<img src="data/mamba-length-extrapolation.png" alt="Mamba ICLR rebuttal" width="500"/>

## Preliminary Studies
Mamba is only trained on sequences up to length 2048, so longer sequences may simply be out of distribution for it. But what if we could just tweak its positional embeddings to make it think it is still at position 2048, the way positional interpolation does for Transformers (https://arxiv.org/abs/2306.15595 and https://kaiokendev.github.io/til)?
The catch is that Mamba has no positional embeddings; it is position-aware purely through its causal, RNN-like architecture. However, the underlying state space model does have a term that controls the discretization of the context (delta), and I find it quite similar to the positional embeddings in Transformers. Figure 2 from the MambaByte paper gives a very good illustration.

<img src="data/mamba_byte.png" width="500"/>

Let's say we want Mamba to operate on a 4096 context. To make it think it is still operating on a 2048 context, we can simply decrease delta to one half of its original value.
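Concretely, the `--delta_ratio` flag used below scales the discretization step. As a minimal sketch of the idea (the exact hook point inside the Mamba block is my assumption, not this repo's code), it amounts to something like:

```python
import torch
import torch.nn.functional as F

def discretization_step(dt_proj_out: torch.Tensor, delta_ratio: float | None = None) -> torch.Tensor:
    # Standard Mamba discretization: delta = softplus(dt_proj(x) (+ bias)).
    delta = F.softplus(dt_proj_out)
    # Hypothetical delta_ratio hook: with delta_ratio=0.5, a 4096-token input
    # advances the SSM state only about as much as 2048 tokens would by default.
    if delta_ratio is not None:
        delta = delta * delta_ratio
    return delta
```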
<details>
<summary>Code</summary>

```bash
python eval.py \
--tokenized PY007/tokenized_proof_pile_test_neox \
--dataset-min-tokens 32768 \
--samples 20 \
--output-file data/original_mamba_delta_ratio_0.5.csv \
--min-tokens 2048 \
--max-tokens 12288 \
--tokens-step 2048 \
--truncate \
-m state-spaces/mamba-2.8b-slimpj \
--delta_ratio 0.5
python plot.py --xmax 12288 --ymax 20 data/original_mamba_delta_ratio_0.5.csv
```
</details>

And it does seem to work, except that the PPL on short context is now worse.

<img src="data/original_mamba_delta_ratio_0.5.csv.png" alt="PPL with delta ratio 0.5" width="500"/>

The obvious next step is to train Mamba on longer context with delta halved, using Mamba trained directly on longer context as a baseline.
To avoid unnecessary confounders, I choose state-spaces/mamba-2.8b-slimpj and train on a [subsample of SlimPajama](https://huggingface.co/datasets/DKYoon/SlimPajama-6B), the same dataset the model was pretrained on.
<details>
<summary>Code</summary>

```bash
accelerate launch --num_processes 8 train.py --batch-size 1 --gradient-accumulate-every 8 --output-dir ./output/slim_delta_1.0_legnth_4096_step_100_lr_2e-5 \
--wandb longmamba --model state-spaces/mamba-2.8b-slimpj --dataset PY007/tokenized_slim6B_train_neox_4096 --max-train-steps 100 --learning-rate 2e-5
accelerate launch --num_processes 8 train.py --batch-size 1 --gradient-accumulate-every 8 --output-dir ./output/slim_delta_0.5_legnth_4096_step_100_lr_2e-5 \
--wandb longmamba --model state-spaces/mamba-2.8b-slimpj --dataset PY007/tokenized_slim6B_train_neox_4096 --max-train-steps 100 --learning-rate 2e-5 --delta_ratio 0.5
```
</details>

<img src="data/mamba_half_delta_training.csv.png" width="500">

It turns out that halving delta performs worse than the baseline. What surprises me is how well the baseline does: it was only trained on context extended from 2048 to 4096, yet it generalizes to sequence lengths up to 12288. This is a very good sign that Mamba can handle longer context without bells and whistles!

## Start Baking
I then train mamba-2.8b-slimpj on a 16384 context length, the longest I can fit on 8 A100 80GB GPUs with FSDP full sharding enabled. The nice thing is that it only takes 9 hours.
<details>
<summary>Code</summary>

```bash
srun accelerate launch --num_processes 8 finetune.py --batch-size 1 --gradient-accumulate-every 16 --output-dir ./output/2.8B_slim_legnth_16384_step_400_lr_3e-5 \
--wandb longmamba --model state-spaces/mamba-2.8b-slimpj --dataset PY007/tokenized_slim6B_train_neox_16384 --delta_ratio 1.0 --max-train-steps 400 --learning-rate 3e-5
# Model is uploaded to https://huggingface.co/PY007/LongMamba_16384_bs128_step400
python eval.py \
--tokenized PY007/tokenized_proof_pile_test_neox \
--dataset-min-tokens 65536 \
--samples 20 \
--output-file data/LongMamba_16384_bs128_step400.csv \
--min-tokens 2048 \
--max-tokens 65536 \
--tokens-step 2048 \
--truncate \
-m PY007/LongMamba_16384_bs128_step400 \
-m state-spaces/mamba-2.8b-slimpj
python plot.py --xmax 65536 --ymin 2 --ymax 10 data/LongMamba_16384_bs128_step400.csv
python plot.py --xmax 65536 --ymin 2 --ymax 4 data/LongMamba_16384_bs128_step400.csv
```
</details>
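For reference, the global batch size here works out to 1 (per-device batch) × 16 (gradient accumulation) × 8 (processes) = 128 sequences, i.e. 128 × 16384 ≈ 2.1M tokens per step, or roughly 840M tokens over the 400 steps — hence the `bs128_step400` in the model name.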

<img src="data/LongMamba_16384_bs128_step400_large.csv.png" width="500">

A closer look:

<img src="data/LongMamba_16384_bs128_step400.csv.png" width="500">

This time the model does very well: the PPL keeps decreasing up to about 40K tokens, and beyond 40K it only rises slightly rather than exploding.

The PPL test alone is not enough, though; we also need to see whether the model can actually retrieve information from its context. To do this, I follow [LongLoRA](https://arxiv.org/abs/2309.12307) and test the model with pass-key retrieval. Here is what the task looks like:

<img src="data/longlora_passkey.png" width="500">

Note that this test is slightly different from [LLMTest_NeedleInAHaystack](https://github.com/gkamradt/LLMTest_NeedleInAHaystack): here the haystack is just one filler sentence repeated N times before and M times after the pass key, whereas in LLMTest_NeedleInAHaystack the haystack is an actual document.
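For reference, a prompt in this style can be built roughly as follows (a minimal sketch following the LongLoRA format; the exact filler wording and the function names in `pass_key.py` may differ):

```python
import random

def build_passkey_prompt(n_garbage_before: int, n_garbage_after: int, pass_key: int | None = None):
    """Build a LongLoRA-style pass-key prompt: one filler sentence repeated
    around a single line that states the pass key."""
    if pass_key is None:
        pass_key = random.randint(10000, 99999)
    filler = ("The grass is green. The sky is blue. The sun is yellow. "
              "Here we go. There and back again. ")
    task = ("There is important info hidden inside a lot of irrelevant text. "
            "Find it and memorize it. I will quiz you about it later.")
    key_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
    question = "What is the pass key? The pass key is"
    prompt = "\n".join([task,
                        filler * n_garbage_before,
                        key_line,
                        filler * n_garbage_after,
                        question])
    return prompt, pass_key

# e.g. sweep the insertion depth by varying the before/after split at a fixed total length
prompt, key = build_passkey_prompt(n_garbage_before=200, n_garbage_after=600)
```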

<details>
<summary>Code</summary>

```bash
python pass_key.py --max_tokens 16384 --num_tests 5
python pass_key.py --max_tokens 32768 --num_tests 5
```
</details>

<img src="data/heatmap_16384.png" width="500">

The model retrieves the pass key nearly perfectly at up to 16384 tokens. We can push further to 32768 tokens and see whether it still holds up.

<img src="data/heatmap_32768.png" width="500">

## Next Step

## References
This repository borrows code from the [yarn repo](https://github.com/jquesnelle/yarn).
,2048,4096,6144,8192,10240,12288,14336,16384,18432,20480,22528,24576,26624,28672,30720,32768,34816,36864,38912,40960,43008,45056,47104,49152,51200,53248,55296,57344,59392,61440,63488,65536
PY007/LongMamba_16384_bs128_step400,4.5625,3.953125,3.625,3.484375,3.359375,3.234375,3.15625,3.0625,3.03125,3.015625,2.984375,2.984375,2.984375,2.96875,2.9375,2.921875,2.921875,2.921875,2.921875,2.921875,2.921875,2.9375,2.9375,2.9375,2.9375,2.9375,2.9375,2.9375,2.96875,2.96875,2.984375,3.015625
state-spaces/mamba-2.8b-slimpj,4.53125,4.125,5.40625,9.5,24.625,70.0
,2048,4096,6144,8192,10240,12288,14336,16384
state-spaces/mamba-2.8b-slimpj,4.3396406173706055,3.880648136138916,4.968588352203369,8.698759078979492,22.647510528564453,64.0375747680664,174.87843322753906,399.41754150390625
output/slim_delta_1.0_slim_trained_legnth_4096,4.3607659339904785,3.7298173904418945,3.4340758323669434,3.260786771774292,3.198638916015625,3.2345333099365234,3.5753703117370605,4.608829021453857
output/slim_delta_0.5_slim_trained_legnth_4096,4.530888080596924,3.860032558441162,3.6022095680236816,3.811201572418213,4.903066635131836,7.090259552001953,10.373894691467285,17.401325225830078
,2048,4096,6144,8192,10240,12288
state-spaces/mamba-2.8b-slimpj,4.53125,4.125,5.40625,9.5,24.625,70.0
state-spaces/mamba-2.8b,4.40625,4.84375,25.0,119.5,356.0,776.0
state-spaces/mamba-1.4b,4.84375,4.21875,4.125,4.46875,5.375,7.09375
state-spaces/mamba-790m,5.1875,4.46875,4.65625,5.9375,9.0625,14.4375
state-spaces/mamba-370m,5.9375,5.03125,4.875,5.28125,6.0625,7.5
state-spaces/mamba-130m,7.75,6.46875,6.03125,5.90625,5.96875,6.21875
,2048,4096,6144,8192,10240,12288
state-spaces/mamba-2.8b-slimpj_delta_0.5,5.4375,4.78125,4.46875,5.0,6.46875,9.8125
state-spaces/mamba-2.8b-slimpj,4.53125,4.125,5.40625,9.5,24.625,70.0
import argparse
import datasets
import gc
import sys
import torch
import warnings
from transformers import AutoTokenizer
from tqdm import tqdm
from modeling.mamba_lm import MambaLMHeadModel


def compute_perplexity(
    encodings, model, tokenizer, add_start_token: bool = True, device=None, max_length=None, sliding_window=256, truncate=False, aggressive_memory=False, hide_progress=False, delta_ratio=None
):
    r"""Compute "sliding window" perplexity on a dataset. Validated against the calculations reported in arXiv 2306.15595"""
    if device is not None:
        assert device in ["gpu", "cpu", "cuda"], "device should be gpu, cpu, or cuda."
        if device == "gpu":
            device = "cuda"
    else:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if add_start_token:
        # leave room for <BOS> token to be added:
        assert (
            tokenizer.bos_token is not None
        ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
        max_tokenized_len = max_length - 1
    else:
        max_tokenized_len = max_length

    encoded_texts = encodings["input_ids"]
    attn_masks = encodings["attention_mask"]

    # With --truncate, each sample is cut to a single window of exactly max_tokenized_len tokens.
    if max_length and truncate:
        encoded_texts = [x[0:max_tokenized_len] for x in encoded_texts]
        attn_masks = [x[0:max_tokenized_len] for x in attn_masks]
        sliding_window = max_tokenized_len

    pbar = tqdm(total=len(encoded_texts), disable=hide_progress)
    nlls = []
    for encoding_index in range(0, len(encoded_texts)):

        labels = torch.tensor(encoded_texts[encoding_index:encoding_index+1])
        seq_len = labels.size(1)

        prev_end_loc = 0
        for begin_loc in range(0, seq_len, sliding_window):

            end_loc = min(begin_loc + max_tokenized_len, seq_len)
            trg_len = end_loc - prev_end_loc
            input_ids = labels[:, begin_loc:end_loc].to(device)
            if add_start_token:
                bos_tokens_tensor = torch.tensor(
                    [[tokenizer.bos_token_id]] * input_ids.size(dim=0)).to(device)
                input_ids = torch.cat(
                    [bos_tokens_tensor, input_ids], dim=1)

            target_ids = input_ids.clone()
            target_ids[:, :-trg_len] = -100
            with torch.no_grad():
                # drop the last position's logits so they align with the shifted next-token targets:
                logits = model(input_ids, delta_ratio=delta_ratio).logits[..., :-1, :].contiguous()
                target_ids = target_ids[..., 1:].contiguous()
                neg_log_likelihood = torch.nn.functional.cross_entropy(
                    logits.view(-1, logits.size(-1)), target_ids.view(-1), reduction='mean')

            if aggressive_memory:
                outputs = None
                input_ids = None
                target_ids = None
                gc.collect()
                torch.cuda.empty_cache()

            nlls.append(neg_log_likelihood)

            ppl = float(torch.exp(torch.stack(nlls).mean()).float().cpu())
            pbar.set_postfix(ppl=ppl)

            prev_end_loc = end_loc
            if end_loc == seq_len:
                break

        pbar.update(1)

    ppl = float(torch.exp(torch.stack(nlls).mean()).float().cpu())
    return {"mean_perplexity": ppl}


def main(args):
    models = [x[0] for x in args.model]
    tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
    tokenizer.pad_token = tokenizer.eos_token

    if args.tokenized:
        try:
            input_texts = datasets.load_from_disk(args.tokenized)
        except:
            input_texts = datasets.load_dataset(
                args.tokenized, name=args.subset, split=args.split)
    else:
        input_texts = datasets.load_dataset(
            args.dataset, name=args.subset, split=args.split)

    def tokenize(example):
        tokenized = tokenizer(
            example[args.feature],
            add_special_tokens=False,
            padding=True,
            truncation=False,
            max_length=sys.maxsize,
            return_attention_mask=True,
        )
        example["input_ids"] = tokenized["input_ids"]
        example["attention_mask"] = tokenized["attention_mask"]
        example["tokenized_len"] = len(tokenized["input_ids"])
        return example

    input_texts = input_texts.map(tokenize, num_proc=64)
    if args.save_tokenized:
        from datasets import DatasetDict
        dataset = DatasetDict({"test": input_texts})
        dataset.push_to_hub(args.save_tokenized)
        print(f"Saved tokenized dataset to {args.save_tokenized}")
        return

    if args.dataset_min_tokens:
        input_texts = input_texts.filter(
            lambda x: x["tokenized_len"] >= args.dataset_min_tokens, num_proc=64)
    if args.samples:
        input_texts = input_texts[:args.samples]

    # Evaluation lengths: either a linear sweep (--tokens-step) or doubling from --min-tokens.
    if args.tokens_step:
        tokens = [x for x in range(
            args.min_tokens, args.max_tokens + 1, args.tokens_step)]
    else:
        tokens = [args.min_tokens]
        while args.min_tokens < args.max_tokens:
            point = tokens[-1] * 2
            if point <= args.max_tokens:
                tokens.append(point)
            else:
                break

    results = []
    for model in tqdm(models, desc="Model", leave=False, disable=args.hide_progress):
        torch.cuda.empty_cache()

        loaded = MambaLMHeadModel.from_pretrained(model, dtype=torch.bfloat16).to("cuda")
        # loaded = torch.compile(loaded)
        loaded.eval()
        result = []
        for max_length in tokens:
            ppl = compute_perplexity(model=loaded, tokenizer=tokenizer, encodings=input_texts,
                                     add_start_token=tokenizer.bos_token is not None, max_length=max_length,
                                     sliding_window=args.sliding_window, truncate=args.truncate,
                                     aggressive_memory=args.aggressive_memory, hide_progress=args.hide_progress, delta_ratio=args.delta_ratio)['mean_perplexity']
            print(f"{model}: {max_length}={ppl}")
            result.append(ppl)

        result.insert(0, model)
        results.append(result)

    if args.output_file:
        with open(args.output_file, "w", encoding="utf-8") as f:
            f.write(f",{','.join([str(x) for x in tokens])}\n")
            for result in results:
                f.write(f"{','.join([str(x) for x in result])}\n")


if __name__ == "__main__":
    warnings.simplefilter("ignore")
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", action="append", nargs="+")
    parser.add_argument("-d", "--dataset", type=str)
    parser.add_argument("-s", "--subset", type=str)
    parser.add_argument("-f", "--feature", type=str)
    parser.add_argument("--max-tokens", type=int, default=8192)
    parser.add_argument("--min-tokens", type=int, default=256)
    parser.add_argument("--dataset-min-tokens", type=int)
    parser.add_argument("--tokens-step", type=int)
    parser.add_argument("--sliding-window", type=int, default=256)
    parser.add_argument("--truncate", action="store_true")
    parser.add_argument("--split", type=str, default="test")
    parser.add_argument("--samples", type=int)
    parser.add_argument("--save-tokenized", type=str)
    parser.add_argument("--tokenized", type=str)
    parser.add_argument("--output-file", type=str)
    parser.add_argument("--aggressive-memory", action="store_true")
    parser.add_argument("--hide-progress", action="store_true")
    parser.add_argument("--delta_ratio", type=float)
    main(parser.parse_args())