Skip to content

Commit

Permalink
Fix examples
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin committed May 23, 2024
1 parent 4121b74 commit d77e518
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 20 deletions.
4 changes: 1 addition & 3 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@
examples = ["auto_fp8 is an easy-to-use model quantization library"]
examples = tokenizer(examples, return_tensors="pt").to("cuda")

ignore_patterns = ["re:.*gate"]

quantize_config = BaseQuantizeConfig(
quant_method="fp8",
activation_scheme="dynamic", # or "static"
ignore_patterns=ignore_patterns,
ignore_patterns=["re:.*lm_head"],
)

model = AutoFP8ForCausalLM.from_pretrained(
Expand Down
21 changes: 4 additions & 17 deletions example_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,11 @@
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
# Llama-3's tokenizer has no pad token; padding=True below requires one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Build calibration examples by rendering chat transcripts through the model's
# chat template. Dataset.select() takes an iterable of indices, not a count,
# hence range(512).
ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512))
examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds]
examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda")

DATASET_ID = "mgoin/ultrachat_2k"
DATASET_SPLIT = "train_sft"
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.map(
lambda batch: {
"text": tokenizer.apply_chat_template(batch["messages"], tokenize=False)
}
)
examples = [sample["text"] for sample in ds]
tokenizer.pad_token = tokenizer.eos_token
examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to(
"cuda"
)

quantize_config = BaseQuantizeConfig(
quant_method="fp8", activation_scheme="static"
) # or "static"
quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")

model = AutoFP8ForCausalLM.from_pretrained(
pretrained_model_dir, quantize_config=quantize_config
Expand Down
24 changes: 24 additions & 0 deletions examples/example_mixtral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Quantize Mixtral-8x7B-Instruct to FP8 with static activation scales.

Downloads a small calibration sample, runs auto_fp8 quantization, and writes
the quantized checkpoint to ``quantized_model_dir``. Requires a CUDA device.
"""

from datasets import load_dataset
from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "mistralai/Mixtral-8x7B-Instruct-v0.1"
quantized_model_dir = "Mixtral-8x7B-Instruct-v0.1-FP8"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)
# Mistral tokenizers ship without a pad token; padding=True below requires one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Calibration data: a handful of chat transcripts rendered through the model's
# chat template. Dataset.select() takes an iterable of indices, not a count,
# hence range(10) — select(10) raises a TypeError.
ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(10))
examples = [
    tokenizer.apply_chat_template(sample["messages"], tokenize=False)
    for sample in ds
]
examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to(
    "cuda"
)

# ignore_patterns excludes layers matching these regexes from quantization —
# here the lm_head and the MoE router ("gate") layers. NOTE(review): presumably
# the routers stay unquantized for accuracy; confirm against auto_fp8 docs.
quantize_config = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    ignore_patterns=["re:.*lm_head", "re:.*gate"],
)

model = AutoFP8ForCausalLM.from_pretrained(
    pretrained_model_dir, quantize_config=quantize_config
)
model.quantize(examples)
model.save_quantized(quantized_model_dir)

0 comments on commit d77e518

Please sign in to comment.