Add ignore_patterns config
mgoin committed May 23, 2024
1 parent a111911 commit 4121b74
Showing 9 changed files with 164 additions and 36 deletions.
2 changes: 1 addition & 1 deletion auto_fp8/__init__.py
@@ -1,5 +1,5 @@
from .modeling import AutoFP8ForCausalLM
from .config import BaseQuantizeConfig
from .modeling import AutoFP8ForCausalLM

__all__ = [
"AutoFP8ForCausalLM",
12 changes: 11 additions & 1 deletion auto_fp8/config.py
@@ -1,5 +1,13 @@
from typing import List


class BaseQuantizeConfig:
def __init__(self, quant_method="fp8", activation_scheme="static"):
def __init__(
self,
quant_method: str = "fp8",
activation_scheme: str = "static",
ignore_patterns: List[str] = [],
):
if quant_method != "fp8":
raise ValueError("Only FP8 quantization is supported.")
if activation_scheme not in ["static", "dynamic"]:
@@ -8,3 +16,5 @@ def __init__(self, quant_method="fp8", activation_scheme="static"):
)
self.quant_method = quant_method
self.activation_scheme = activation_scheme
self.ignore_patterns = ignore_patterns
self.ignored_layers = []
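
For context, a minimal usage sketch of the extended config (not part of this commit; the pattern strings are illustrative):

# Illustrative only: constructing the config with the new ignore_patterns field.
from auto_fp8 import BaseQuantizeConfig

cfg = BaseQuantizeConfig(
    quant_method="fp8",
    activation_scheme="static",
    ignore_patterns=["re:.*lm_head", "model.layers.0.mlp.gate"],  # "re:"-prefixed regex or exact module name
)
print(cfg.ignore_patterns)  # the user-supplied patterns
print(cfg.ignored_layers)   # starts empty; resolved to concrete names in AutoFP8ForCausalLM.__init__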
59 changes: 51 additions & 8 deletions auto_fp8/modeling.py
@@ -1,24 +1,34 @@
import re
from typing import List

import torch
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers import AutoModelForCausalLM

from auto_fp8.config import BaseQuantizeConfig
from auto_fp8.quantize import (
quantize_weights,
quantize_activations,
quantize_weights,
save_quantized_model,
)
from auto_fp8.config import BaseQuantizeConfig


class AutoFP8ForCausalLM:
def __init__(
self,
model: PreTrainedModel,
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
):
self.model = model
self.model_type = self.model.config.model_type
self.quantize_config = quantize_config
self.config = self.model.config

# Gather the Linear module names that we want to ignore
quantize_config.ignored_layers = get_layers_to_ignore(
self.model, quantize_config.ignore_patterns
)

self.quantize_config = quantize_config

@classmethod
def from_pretrained(
cls,
@@ -94,16 +104,49 @@ def _prepare_calibration_data(calibration_tokens):
return calibration_tokens

# Always quantize the weights as they do not require calibration data
quantize_weights(self.model)
quantize_weights(self.model, self.quantize_config)

if self.quantize_config.activation_scheme == "static":
quantize_activations(
self.model, _prepare_calibration_data(calibration_tokens)
self.model,
self.quantize_config,
_prepare_calibration_data(calibration_tokens),
)

# import copy
# for layer in self.model.model.layers:
# layer.self_attn.kv_scale = copy.deepcopy(layer.self_attn.k_proj.act_scale)

def save_quantized(self, save_dir):
save_quantized_model(
self.model,
activation_scheme=self.quantize_config.activation_scheme,
quant_config=self.quantize_config,
save_dir=save_dir,
)


def get_layers_to_ignore(model, ignore_patterns) -> List[str]:
ignored_layers = set()

# TODO: don't always ignore lm_head
ignore_patterns.append("re:.*lm_head")

for name, linear in model.named_modules():
if not isinstance(linear, torch.nn.Linear):
continue

for ignore_pattern in ignore_patterns:
regex_prefix = "re:"
if ignore_pattern.startswith(regex_prefix):
# check if name matches regex and add to set if true
regex_pattern = ignore_pattern[len(regex_prefix) :]
print(regex_pattern)
print(name)
if re.search(regex_pattern, name):
ignored_layers.add(name)
else:
# else, exact match
if ignore_pattern == name:
ignored_layers.add(name)

return list(ignored_layers)
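
The rule above treats a pattern prefixed with "re:" as a regular expression applied to the module name via re.search, and any other pattern as an exact module-name match; "re:.*lm_head" is always appended so the LM head stays unquantized. A standalone sketch of that rule, using made-up module names:

# Standalone illustration of the matching logic in get_layers_to_ignore (module names are made up).
import re

def matches(pattern: str, name: str) -> bool:
    if pattern.startswith("re:"):
        return re.search(pattern[len("re:"):], name) is not None
    return pattern == name

names = ["model.layers.0.self_attn.q_proj", "model.layers.0.block_sparse_moe.gate", "lm_head"]
patterns = ["re:.*gate", "lm_head"]
print([n for n in names if any(matches(p, n) for p in patterns)])
# -> ['model.layers.0.block_sparse_moe.gate', 'lm_head']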
77 changes: 56 additions & 21 deletions auto_fp8/quantize.py
@@ -1,10 +1,13 @@
import gc
import re
from typing import Tuple
from typing import List, Tuple

import torch
import transformers
import tqdm
from transformers import AutoTokenizer
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

from .config import BaseQuantizeConfig


# HACK: Override the dtype_byte_size function in transformers to support float8 types
@@ -39,8 +42,8 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
if tensor.numel() == 0:
# Deal with empty tensors (triggered by empty MoE experts)
min_val, max_val = (
torch.tensor(0.0, dtype=tensor.dtype),
torch.tensor(1.0, dtype=tensor.dtype),
torch.tensor(-16.0, dtype=tensor.dtype),
torch.tensor(16.0, dtype=tensor.dtype),
)
else:
min_val, max_val = tensor.aminmax()
@@ -80,7 +83,9 @@ def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype):


class FP8StaticLinearQuantizer(torch.nn.Module):
def __init__(self, qweight, weight_scale, bias):
def __init__(
self, qweight: torch.Tensor, weight_scale: torch.Tensor, bias: torch.Tensor
):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
@@ -105,7 +110,13 @@ def forward(self, x):


class FP8StaticLinear(torch.nn.Module):
def __init__(self, qweight, weight_scale, bias, act_scale=0.0):
def __init__(
self,
qweight: torch.Tensor,
weight_scale: torch.Tensor,
bias: torch.Tensor,
act_scale: float = 1.0,
):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
@@ -133,7 +144,7 @@ def forward(self, x):


class FP8DynamicLinear(torch.nn.Module):
def __init__(self, qweight, scale, bias):
def __init__(self, qweight: torch.Tensor, scale: torch.Tensor, bias: torch.Tensor):
super().__init__()
self.weight = torch.nn.Parameter(qweight, requires_grad=False)
self.weight_scale = torch.nn.Parameter(scale, requires_grad=False)
@@ -152,21 +163,28 @@ def forward(self, x):
return output


def replace_module(model, name, new_module):
def replace_module(model: AutoModelForCausalLM, name: str, new_module: torch.nn.Module):
if "." in name:
parent_name = name.rsplit(".", 1)[0]
child_name = name[len(parent_name) + 1 :]
parent = model.model.get_submodule(parent_name)
parent = model.get_submodule(parent_name)
else:
parent_name = ""
parent = model.model
parent = model
child_name = name
setattr(parent, child_name, new_module)


def quantize_weights(model):
for name, linear in model.model.named_modules():
if not isinstance(linear, torch.nn.Linear):
def quantize_weights(
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
ignored_layers: List[str] = [],
):
for name, linear in model.named_modules():
if (
not isinstance(linear, torch.nn.Linear)
or name in quantize_config.ignored_layers
):
continue
quant_weight, quant_scale = per_tensor_quantize(linear.weight)
quant_linear = FP8DynamicLinear(quant_weight, quant_scale, linear.bias)
@@ -175,9 +193,17 @@ def quantize_weights(model):
cleanup_memory()


def quantize_activations(model, calibration_tokens):
for name, dynamic_quant_linear in model.model.named_modules():
if not isinstance(dynamic_quant_linear, FP8DynamicLinear):
def quantize_activations(
model: AutoModelForCausalLM,
quantize_config: BaseQuantizeConfig,
calibration_tokens,
ignored_layers: List[str] = [],
):
for name, dynamic_quant_linear in model.named_modules():
if (
not isinstance(dynamic_quant_linear, FP8DynamicLinear)
or name in quantize_config.ignored_layers
):
continue
quantizer = FP8StaticLinearQuantizer(
dynamic_quant_linear.weight,
@@ -196,8 +222,11 @@ def quantize_activations(model, calibration_tokens):
pbar.update(1)

# Replace dynamic quantizer with StaticLinear for export
for name, quantizer in model.model.named_modules():
if not isinstance(quantizer, FP8StaticLinearQuantizer):
for name, quantizer in model.named_modules():
if (
not isinstance(quantizer, FP8StaticLinearQuantizer)
or name in quantize_config.ignored_layers
):
continue
static_proj = FP8StaticLinear(
quantizer.weight,
@@ -210,13 +239,19 @@ def quantize_activations(model, calibration_tokens):
cleanup_memory()


def save_quantized_model(model, activation_scheme, save_dir):
def save_quantized_model(
model: AutoModelForCausalLM,
quant_config: BaseQuantizeConfig,
save_dir: str,
ignored_layers: List[str] = [],
):
print(model)
print(f"Saving the model to {save_dir}")
static_q_dict = {
"quantization_config": {
"quant_method": "fp8",
"activation_scheme": activation_scheme,
"activation_scheme": quant_config.activation_scheme,
"ignored_layers": quant_config.ignored_layers,
}
}
model.config.update(static_q_dict)
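
After save_quantized_model runs, the serialized config.json should carry the quantization metadata built above. A hedged sketch of reading it back; the output directory name is an assumption:

# Sketch: inspect the saved quantization metadata (directory name assumed).
import json
import os

save_dir = "Meta-Llama-3-8B-Instruct-FP8"
with open(os.path.join(save_dir, "config.json")) as f:
    qcfg = json.load(f)["quantization_config"]

print(qcfg["quant_method"])       # "fp8"
print(qcfg["activation_scheme"])  # "static" or "dynamic"
print(qcfg["ignored_layers"])     # concrete module names resolved from ignore_patterns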
9 changes: 7 additions & 2 deletions example.py
@@ -1,4 +1,5 @@
from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
@@ -8,9 +9,13 @@
examples = ["auto_fp8 is an easy-to-use model quantization library"]
examples = tokenizer(examples, return_tensors="pt").to("cuda")

ignore_patterns = ["re:.*gate"]

quantize_config = BaseQuantizeConfig(
quant_method="fp8", activation_scheme="dynamic"
) # or "static"
quant_method="fp8",
activation_scheme="dynamic", # or "static"
ignore_patterns=ignore_patterns,
)

model = AutoFP8ForCausalLM.from_pretrained(
pretrained_model_dir, quantize_config=quantize_config
33 changes: 33 additions & 0 deletions example_dataset.py
@@ -0,0 +1,33 @@
from datasets import load_dataset
from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct"
quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8"

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True)

DATASET_ID = "mgoin/ultrachat_2k"
DATASET_SPLIT = "train_sft"
ds = load_dataset(DATASET_ID, split=DATASET_SPLIT)
ds = ds.map(
lambda batch: {
"text": tokenizer.apply_chat_template(batch["messages"], tokenize=False)
}
)
examples = [sample["text"] for sample in ds]
tokenizer.pad_token = tokenizer.eos_token
examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to(
"cuda"
)

quantize_config = BaseQuantizeConfig(
quant_method="fp8", activation_scheme="static"
) # or "dynamic"

model = AutoFP8ForCausalLM.from_pretrained(
pretrained_model_dir, quantize_config=quantize_config
)
model.quantize(examples)
model.save_quantized(quantized_model_dir)
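
A small follow-on sketch (not part of the commit) for checking which modules were actually excluded, applicable to either example above; it assumes from_pretrained returns the AutoFP8ForCausalLM wrapper, as its classmethod signature suggests:

# Sketch: list the modules resolved as ignored (uses `model` from either example above).
for layer_name in model.quantize_config.ignored_layers:
    print("kept in original precision:", layer_name)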
2 changes: 1 addition & 1 deletion examples/quantize.py
@@ -4,8 +4,8 @@
from typing import Tuple

import torch
import transformers
import tqdm
import transformers
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

2 changes: 1 addition & 1 deletion setup.py
@@ -1,4 +1,4 @@
from setuptools import setup, find_packages
from setuptools import find_packages, setup

setup(
name="auto_fp8",
4 changes: 3 additions & 1 deletion tests/test_auto_fp8.py
@@ -1,7 +1,9 @@
import os
import shutil

from transformers import AutoTokenizer

from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig
import shutil


def test_quantization():
