Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jzhang38 committed Feb 3, 2024
0 parents commit 7621b1a
Show file tree
Hide file tree
Showing 64 changed files with 8,352 additions and 0 deletions.
148 changes: 148 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# LongMamba
This repo contains my exploration on Mamba's context scaling. It includes code to: 1. train Mamba on longer context 2. evaluate Mamba's PPL on the proof pile test set. 3. perform needle in a haystack test (pass-key retrieval).
## Install
<details>
<summary>Code</summary>
```bash
conda create -n longmamba python=3.10 -y
conda activate longmamba
pip3 install torch --index-url https://download.pytorch.org/whl/cu118
pip install causal-conv1d>=1.1.0
pip install mamba-ssm
pip install -r requirements.txt
```
</details>

## Mamba Cannot Directly Handle Longer Context
We first run Mamba on the proof pile test set and note down the average PPL. It is observed that the PPL explodes when the context length increases.
<details>
<summary>Code</summary>

```bash
python eval.py \
--tokenized PY007/tokenized_proof_pile_test_neox \
--dataset-min-tokens 32768 \
--samples 20 \
--output-file data/original_mamba.csv \
--min-tokens 2048 \
--max-tokens 12288 \
--tokens-step 2048 \
--truncate \
-m state-spaces/mamba-2.8b-slimpj \
-m state-spaces/mamba-2.8b \
-m state-spaces/mamba-1.4b \
-m state-spaces/mamba-790m \
-m state-spaces/mamba-370m \
-m state-spaces/mamba-130m
python plot.py --xmax 12288 --ymax 20 data/original_mamba.csv
```
</details>
<img src="data/original_mamba.csv.png" alt="PPL explode when increasing the context length" width="500"/>

To further validates that phenomenon, let's look at below plot from [Mamba's ICLR rebuttal](https://openreview.net/forum?id=AL1fq05o7H) (unfortunately the paper not accepted). It was generated by taking the validation set of the Pile dataset, feeding in each example with no padding or concatenation, and measuring the loss per token.
<img src="data/mamba-length-extrapolation.png" alt="Mamba ICLR rebuttal" width="500"/>


## Preliminary Studies
Mamba is only trained on sequence length up to 2048. It is possible that sequence longer than that is OOD for it. But what if we just tweak the positional embeddings to make it think it's still at position 2048, just like what the positional interpolation is doing to Transformer (https://arxiv.org/abs/2306.15595 and https://kaiokendev.github.io/til)?
The thing is Mamba does not have positional embeddings. It is position-aware simply through its causal RNN-like architecture. But the underlying state space model does have a term to control the discretization of the context, and I find it quite similar to the positional embeddings in Transformer. The Figure 2 from the MambaByte paper gives a very good illustration.

<img src="data/mamba_byte.png" width="500"/>

Let's say we want Mamba to operate in a 4096 context. To make it think it's still operating at 2048, we can simply decrease the delta to one half of the original value.
<details>
<summary>Code</summary>
```bash
python eval.py \
--tokenized PY007/tokenized_proof_pile_test_neox \
--dataset-min-tokens 32768 \
--samples 20 \
--output-file data/original_mamba_delta_ratio_0.5.csv \
--min-tokens 2048 \
--max-tokens 12288 \
--tokens-step 2048 \
--truncate \
-m state-spaces/mamba-2.8b-slimpj \
--delta_ratio 0.5
python plot.py --xmax 12288 --ymax 20 data/original_mamba_delta_ratio_0.5.csv
</details>

And it does seem to work, except that the PPL on short context is now worse.

<img src="data/original_mamba_delta_ratio_0.5.csv.png" alt="PPL explode when increasing the context length" width="500"/>


The very obvious next thing to do is to train Mamba on longer context with the delta value halfed, and we can use mamba directly trained on longer context as a baseline.
To avoid uncesary counfounders, I choose state-spaces/mamba-2.8b-slimpj and train on a [subsample of slimpajama](DKYoon/SlimPajama-6B), the same dataset that Mamba is pretrained on.
<details>
<summary>Code</summary>
```bash
accelerate launch --num_processes 8 train.py --batch-size 1 --gradient-accumulate-every 8 --output-dir ./output/slim_delta_1.0_legnth_4096_step_100_lr_2e-5 \
--wandb longmamba --model state-spaces/mamba-2.8b-slimpj --dataset PY007/tokenized_slim6B_train_neox_4096 --max-train-steps 100 --learning-rate 2e-5
accelerate launch --num_processes 8 train.py --batch-size 1 --gradient-accumulate-every 8 --output-dir ./output/slim_delta_0.5_legnth_4096_step_100_lr_2e-5 \
--wandb longmamba --model state-spaces/mamba-2.8b-slimpj --dataset PY007/tokenized_slim6B_train_neox_4096 --max-train-steps 100 --learning-rate 2e-5 --delta_ratio 0.5
```
</details>

<img src="data/mamba_half_delta_training.csv.png" width="500">

Turns out halfing the delta value performs worse than the baseline. What suprises me is how good the baseline is doing: it is only trained on 2048 --> 4096 context, but it generalizes to sequence length up to 12288. This is a very good sign that Mamba is capable of handling longer context without bells and whistles!

## Start Baking
I then train mamba-2.8b-slimpj on 16384 context length, the longest that I can fit with 8 A100 80GB and FSDP Fully Shard enabled. The nice thing is it only taks 9 hours.
<details>
<summary>Code</summary>
```bash
srun accelerate launch --num_processes 8 finetune.py --batch-size 1 --gradient-accumulate-every 16 --output-dir ./output/2.8B_slim_legnth_16384_step_400_lr_3e-5 \
--wandb longmamba --model state-spaces/mamba-2.8b-slimpj --dataset PY007/tokenized_slim6B_train_neox_16384 --delta_ratio 1.0 --max-train-steps 400 --learning-rate 3e-5
# Model is uploaded to https://huggingface.co/PY007/LongMamba_16384_bs128_step400
python eval.py \
--tokenized PY007/tokenized_proof_pile_test_neox \
--dataset-min-tokens 65536 \
--samples 20 \
--output-file data/LongMamba_16384_bs128_step400.csv \
--min-tokens 2048 \
--max-tokens 65536 \
--tokens-step 2048 \
--truncate \
-m PY007/LongMamba_16384_bs128_step400 \
-m state-spaces/mamba-2.8b-slimpj
python plot.py --xmax 65536 --ymin 2 --ymax 10 data/LongMamba_16384_bs128_step400.csv
python plot.py --xmax 65536 --ymin 2 --ymax 4 data/LongMamba_16384_bs128_step400.csv
```
</details>

<img src="data/LongMamba_16384_bs128_step400_large.csv.png" width="500">

A closeer look:

<img src="data/LongMamba_16384_bs128_step400.csv.png" width="500">


This time it is just doing so good. The PPL keeps decreasing till 40K. Even after 40K, it just increase a little bit rather than directly explode.

Only the PPL test is not enought though. We need to see whether can it really memorize things. To do this, I follow [LongLora](https://arxiv.org/abs/2309.12307) and test my model with pass-key retrieval. Here is what the task looks like:

<img src="data/longlora_passkey.png" width="500">

Note that this test is slightly different from https://github.com/gkamradt/LLMTest_NeedleInAHaystack because in this test the haystack is just one sentence repeated N and M times, while in https://github.com/gkamradt/LLMTest_NeedleInAHaystack, the haystack is a an actualy docment.

<details>
<summary>Code</summary>
```bash
python pass_key.py --max_tokens 16384 --num_tests 5
python pass_key.py --max_tokens 32768 --num_tests 5
```
</details>

<img src="data/heatmap_16384.png" width="500">

It can be observed that the model retrieves nearly perfectly on 16384. We can further test on 32768 tokens, and see if it still works well.

<img src="data/heatmap_32768.png" width="500">

## Next Step


## References
This repository borrows code from the [yarn repo](https://github.com/jquesnelle/yarn).
3 changes: 3 additions & 0 deletions data/LongMamba_16384_bs128_step400.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
,2048,4096,6144,8192,10240,12288,14336,16384,18432,20480,22528,24576,26624,28672,30720,32768,34816,36864,38912,40960,43008,45056,47104,49152,51200,53248,55296,57344,59392,61440,63488,65536
PY007/LongMamba_16384_bs128_step400,4.5625,3.953125,3.625,3.484375,3.359375,3.234375,3.15625,3.0625,3.03125,3.015625,2.984375,2.984375,2.984375,2.96875,2.9375,2.921875,2.921875,2.921875,2.921875,2.921875,2.921875,2.9375,2.9375,2.9375,2.9375,2.9375,2.9375,2.9375,2.96875,2.96875,2.984375,3.015625
state-spaces/mamba-2.8b-slimpj,4.53125,4.125,5.40625,9.5,24.625,70.0
Binary file added data/LongMamba_16384_bs128_step400.csv.pdf
Binary file not shown.
Binary file added data/LongMamba_16384_bs128_step400.csv.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/LongMamba_16384_bs128_step400_large.csv.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/heatmap_16384.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/heatmap_32768.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/longlora_passkey.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/mamba-length-extrapolation.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/mamba_byte.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions data/mamba_half_delta_training.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
,2048,4096,6144,8192,10240,12288,14336,16384
state-spaces/mamba-2.8b-slimpj,4.3396406173706055,3.880648136138916,4.968588352203369,8.698759078979492,22.647510528564453,64.0375747680664,174.87843322753906,399.41754150390625
output/slim_delta_1.0_slim_trained_legnth_4096,4.3607659339904785,3.7298173904418945,3.4340758323669434,3.260786771774292,3.198638916015625,3.2345333099365234,3.5753703117370605,4.608829021453857
output/slim_delta_0.5_slim_trained_legnth_4096,4.530888080596924,3.860032558441162,3.6022095680236816,3.811201572418213,4.903066635131836,7.090259552001953,10.373894691467285,17.401325225830078
Binary file added data/mamba_half_delta_training.csv.pdf
Binary file not shown.
Binary file added data/mamba_half_delta_training.csv.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
7 changes: 7 additions & 0 deletions data/original_mamba.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
,2048,4096,6144,8192,10240,12288
state-spaces/mamba-2.8b-slimpj,4.53125,4.125,5.40625,9.5,24.625,70.0
state-spaces/mamba-2.8b,4.40625,4.84375,25.0,119.5,356.0,776.0
state-spaces/mamba-1.4b,4.84375,4.21875,4.125,4.46875,5.375,7.09375
state-spaces/mamba-790m,5.1875,4.46875,4.65625,5.9375,9.0625,14.4375
state-spaces/mamba-370m,5.9375,5.03125,4.875,5.28125,6.0625,7.5
state-spaces/mamba-130m,7.75,6.46875,6.03125,5.90625,5.96875,6.21875
Binary file added data/original_mamba.csv.pdf
Binary file not shown.
Binary file added data/original_mamba.csv.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 3 additions & 0 deletions data/original_mamba_delta_ratio_0.5.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
,2048,4096,6144,8192,10240,12288
state-spaces/mamba-2.8b-slimpj_delta_0.5,5.4375,4.78125,4.46875,5.0,6.46875,9.8125
state-spaces/mamba-2.8b-slimpj,4.53125,4.125,5.40625,9.5,24.625,70.0
Binary file added data/original_mamba_delta_ratio_0.5.csv.pdf
Binary file not shown.
Binary file added data/original_mamba_delta_ratio_0.5.csv.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
193 changes: 193 additions & 0 deletions eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
import argparse
import datasets
import gc
import sys
import torch
import warnings
from transformers import AutoTokenizer
from tqdm import tqdm
from modeling.mamba_lm import MambaLMHeadModel

def compute_perplexity(
encodings, model, tokenizer, add_start_token: bool = True, device=None, max_length=None, sliding_window=256, truncate=False, aggressive_memory=False, hide_progress=False, delta_ratio=None
):
r"""Compute "sliding window" perplexity on a dataset. Validated against the calculations reported in arXiv 2306.15595"""
if device is not None:
assert device in ["gpu", "cpu",
"cuda"], "device should be either gpu or cpu."
if device == "gpu":
device = "cuda"
else:
device = "cuda" if torch.cuda.is_available() else "cpu"

if add_start_token:
# leave room for <BOS> token to be added:
assert (
tokenizer.bos_token is not None
), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
max_tokenized_len = max_length - 1
else:
max_tokenized_len = max_length

encoded_texts = encodings["input_ids"]
attn_masks = encodings["attention_mask"]


if max_length and truncate:
encoded_texts = [x[0:max_tokenized_len] for x in encoded_texts]
attn_masks = [x[0:max_tokenized_len] for x in attn_masks]
sliding_window = max_tokenized_len

pbar = tqdm(total=len(encoded_texts), disable=hide_progress)
nlls = []
for encoding_index in range(0, len(encoded_texts)):

labels = torch.tensor(encoded_texts[encoding_index:encoding_index+1])
seq_len = labels.size(1)

prev_end_loc = 0
for begin_loc in range(0, seq_len, sliding_window):

end_loc = min(begin_loc + max_tokenized_len, seq_len)
trg_len = end_loc - prev_end_loc
input_ids = labels[:, begin_loc:end_loc].to(device)
if add_start_token:
bos_tokens_tensor = torch.tensor(
[[tokenizer.bos_token_id]] * input_ids.size(dim=0)).to(device)
input_ids = torch.cat(
[bos_tokens_tensor, input_ids], dim=1)

target_ids = input_ids.clone()
target_ids[:, :-trg_len] = -100
with torch.no_grad():
# only get the logits for the last 1024 tokens:
logits = model(input_ids, delta_ratio=delta_ratio).logits[..., :-1, :].contiguous()
target_ids = target_ids[..., 1:].contiguous()
neg_log_likelihood = torch.nn.functional.cross_entropy(
logits.view(-1, logits.size(-1)), target_ids.view(-1), reduction='mean')

if aggressive_memory:
outputs = None
input_ids = None
target_ids = None
gc.collect()
torch.cuda.empty_cache()

nlls.append(neg_log_likelihood)

ppl = float(torch.exp(torch.stack(nlls).mean()).float().cpu())
pbar.set_postfix(ppl=ppl)

prev_end_loc = end_loc
if end_loc == seq_len:
break

pbar.update(1)

ppl = float(torch.exp(torch.stack(nlls).mean()).float().cpu())
return {"mean_perplexity": ppl}


def main(args):
models = [x[0] for x in args.model]
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
tokenizer.pad_token = tokenizer.eos_token

if args.tokenized:
try:
input_texts = datasets.load_from_disk(args.tokenized)
except:
input_texts = datasets.load_dataset(
args.tokenized, name=args.subset, split=args.split)
else:
input_texts = datasets.load_dataset(
args.dataset, name=args.subset, split=args.split)

def tokenize(example):
tokenized = tokenizer(
example[args.feature],
add_special_tokens=False,
padding=True,
truncation=False,
max_length=sys.maxsize,
return_attention_mask=True,
)
example["input_ids"] = tokenized["input_ids"]
example["attention_mask"] = tokenized["attention_mask"]
example["tokenized_len"] = len(tokenized["input_ids"])
return example

input_texts = input_texts.map(tokenize, num_proc=64)
if args.save_tokenized:
from datasets import DatasetDict
dataset = DatasetDict({"test": input_texts})
dataset.push_to_hub(args.save_tokenized)
print(f"Saved tokenized dataset to {args.save_tokenized}")
return

if args.dataset_min_tokens:
input_texts = input_texts.filter(
lambda x: x["tokenized_len"] >= args.dataset_min_tokens, num_proc=64)
if args.samples:
input_texts = input_texts[:args.samples]

if args.tokens_step:
tokens = [x for x in range(
args.min_tokens, args.max_tokens + 1, args.tokens_step)]
else:
tokens = [args.min_tokens]
while args.min_tokens < args.max_tokens:
point = tokens[-1] * 2
if point <= args.max_tokens:
tokens.append(point)
else:
break

results = []
for model in tqdm(models, desc="Model", leave=False, disable=args.hide_progress):
torch.cuda.empty_cache()

loaded = MambaLMHeadModel.from_pretrained(model, dtype=torch.bfloat16).to("cuda")
# loaded = torch.compile(loaded)
loaded.eval()
result = []
for max_length in tokens:
ppl = compute_perplexity(model=loaded, tokenizer=tokenizer, encodings=input_texts,
add_start_token=tokenizer.bos_token is not None, max_length=max_length,
sliding_window=args.sliding_window, truncate=args.truncate,
aggressive_memory=args.aggressive_memory, hide_progress=args.hide_progress, delta_ratio=args.delta_ratio)['mean_perplexity']
print(f"{model}: {max_length}={ppl}")
result.append(ppl)

result.insert(0, model)
results.append(result)

if args.output_file:
with open(args.output_file, "w", encoding="utf-8") as f:
f.write(f",{','.join([str(x) for x in tokens])}\n")
for result in results:
f.write(f"{','.join([str(x) for x in result])}\n")


if __name__ == "__main__":
warnings.simplefilter("ignore")
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", action="append", nargs="+")
parser.add_argument("-d", "--dataset", type=str)
parser.add_argument("-s", "--subset", type=str)
parser.add_argument("-f", "--feature", type=str)
parser.add_argument("--max-tokens", type=int, default=8192)
parser.add_argument("--min-tokens", type=int, default=256)
parser.add_argument("--dataset-min-tokens", type=int)
parser.add_argument("--tokens-step", type=int)
parser.add_argument("--sliding-window", type=int, default=256)
parser.add_argument("--truncate", action="store_true")
parser.add_argument("--split", type=str, default="test")
parser.add_argument("--samples", type=int)
parser.add_argument("--save-tokenized", type=str)
parser.add_argument("--tokenized", type=str)
parser.add_argument("--output-file", type=str)
parser.add_argument("--aggressive-memory", action="store_true")
parser.add_argument("--hide-progress", action="store_true")
parser.add_argument("--delta_ratio", type=float)
main(parser.parse_args())
Binary file added modeling/__pycache__/mamba_lm.cpython-310.pyc
Binary file not shown.
Binary file added modeling/__pycache__/mamba_module.cpython-310.pyc
Binary file not shown.
Loading

0 comments on commit 7621b1a

Please sign in to comment.