Commit

Big refactor into package

mgoin committed May 10, 2024
1 parent 424a4c9 commit 4e6ed88
Showing 9 changed files with 907 additions and 0 deletions.
189 changes: 189 additions & 0 deletions .gitignore
@@ -0,0 +1,189 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
docs/source/getting_started/examples/*.rst
!**/*.template.rst

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# VSCode
.vscode/

# DS Store
.DS_Store

# Results
*.csv

# Python pickle files
*.pkl

# Sphinx documentation
_build/

# vim swap files
*.swo
*.swp

# hip files generated by PyTorch
*.hip
*_hip*
hip_compat.h

# Benchmark dataset
*.json
2 changes: 2 additions & 0 deletions auto_fp8/__init__.py
@@ -0,0 +1,2 @@
from .modeling import AutoFP8ForCausalLM
from .config import BaseQuantizeConfig
10 changes: 10 additions & 0 deletions auto_fp8/config.py
@@ -0,0 +1,10 @@
class BaseQuantizeConfig:
def __init__(self, quant_method="fp8", activation_scheme="static"):
if quant_method != "fp8":
raise ValueError("Only FP8 quantization is supported.")
if activation_scheme not in ["static", "dynamic"]:
raise ValueError(
"Invalid activation_scheme. Choose either 'static' or 'dynamic'."
)
self.quant_method = quant_method
self.activation_scheme = activation_scheme
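
As a quick illustration of the new config API (not part of the commit itself; the argument values below are placeholders), the class above is constructed like this:

from auto_fp8 import BaseQuantizeConfig

# "static" later calibrates activation scales via quantize_activations();
# "dynamic" skips that calibration step in AutoFP8ForCausalLM.quantize().
quantize_config = BaseQuantizeConfig(
    quant_method="fp8",          # anything other than "fp8" raises ValueError
    activation_scheme="static",  # must be "static" or "dynamic"
)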
118 changes: 118 additions & 0 deletions auto_fp8/modeling.py
@@ -0,0 +1,118 @@
import torch
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from auto_fp8.quantize import (
quantize_weights,
quantize_activations,
save_quantized_model,
)
from auto_fp8.config import BaseQuantizeConfig


class AutoFP8ForCausalLM:
def __init__(
self,
model: PreTrainedModel,
quantize_config: BaseQuantizeConfig,
):
# super().__init__()

self.model = model
self.model_type = self.model.config.model_type
self.quantize_config = quantize_config
self.config = self.model.config

@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
quantize_config: BaseQuantizeConfig,
**model_init_kwargs,
):
"""load un-quantized pretrained model to cpu"""

if not torch.cuda.is_available():
raise EnvironmentError(
"Load pretrained model to do quantization requires CUDA available."
)

def skip(*args, **kwargs):
pass

torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip

# Parameters related to loading from Hugging Face Hub
cache_dir = model_init_kwargs.pop("cache_dir", None)
force_download = model_init_kwargs.pop("force_download", False)
resume_download = model_init_kwargs.pop("resume_download", False)
proxies = model_init_kwargs.pop("proxies", None)
local_files_only = model_init_kwargs.pop("local_files_only", False)
use_auth_token = model_init_kwargs.pop("use_auth_token", None)
revision = model_init_kwargs.pop("revision", None)
subfolder = model_init_kwargs.pop("subfolder", "")
commit_hash = model_init_kwargs.pop("_commit_hash", None)

cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"proxies": proxies,
"resume_download": resume_download,
"local_files_only": local_files_only,
"use_auth_token": use_auth_token,
"revision": revision,
"subfolder": subfolder,
"_commit_hash": commit_hash,
}

torch.cuda.empty_cache()

# Important defaults
        if "torch_dtype" not in model_init_kwargs:
            model_init_kwargs["torch_dtype"] = "auto"

        if "device_map" not in model_init_kwargs:
            model_init_kwargs["device_map"] = "auto"

merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}
print(merged_kwargs)
model = AutoModelForCausalLM.from_pretrained(
pretrained_model_name_or_path, **merged_kwargs
)

model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
if any(k in model_config for k in seq_len_keys):
for key in seq_len_keys:
if key in model_config:
model.seqlen = model_config[key]
break
else:
print(
"can't get model's sequence length from model config, will set to 2048."
)
model.seqlen = 2048
model.eval()

return cls(model, quantize_config)

def quantize(self, calibration_tokens):
def _prepare_calibration_data(calibration_tokens):
if hasattr(calibration_tokens, "input_ids"):
return calibration_tokens.input_ids
return calibration_tokens

if self.quantize_config.activation_scheme == "dynamic":
quantize_weights(self.model)
else:
quantize_weights(self.model)
quantize_activations(
self.model, _prepare_calibration_data(calibration_tokens)
)

def save_quantized(self, save_dir):
save_quantized_model(
self.model,
activation_scheme=self.quantize_config.activation_scheme,
save_dir=save_dir,
)
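
Taken together, the refactor exposes a three-step workflow. The sketch below is illustrative rather than part of the commit: the model id, output directory, and calibration text are placeholders, and the tokenizer call assumes the standard transformers API; only from_pretrained, quantize, and save_quantized come from the diff above.

from transformers import AutoTokenizer
from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig

pretrained_model_dir = "facebook/opt-125m"  # placeholder model id
quantized_model_dir = "opt-125m-fp8"        # placeholder output directory

quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static")

# Loads the HF model with torch_dtype="auto" and device_map="auto" unless overridden.
model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)

# With the "static" scheme, quantize() needs calibration tokens; a tokenizer
# output is also accepted, since _prepare_calibration_data() unwraps .input_ids.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir)
calibration_tokens = tokenizer(
    "The quick brown fox jumps over the lazy dog.", return_tensors="pt"
).input_ids

model.quantize(calibration_tokens)
model.save_quantized(quantized_model_dir)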