Merge pull request #3 from AnyISalIn/main

fix vram leak in calibration
mgoin authored May 8, 2024
2 parents (e3b4e46 + a30af69), commit f895066
Showing 1 changed file with 7 additions and 3 deletions.
quantize.py (10 changes: 7 additions & 3 deletions)
@@ -6,6 +6,7 @@
 import torch
 import torch.functional as F
 import transformers
+import tqdm
 from datasets import load_dataset
 from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -162,7 +163,7 @@ def forward(self, x):
 def replace_module(model, name, new_module):
     if "." in name:
         parent_name = name.rsplit(".", 1)[0]
-        child_name = name[len(parent_name) + 1 :]
+        child_name = name[len(parent_name) + 1:]
         parent = model.model.get_submodule(parent_name)
     else:
         parent_name = ""
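
For context, the hunk above shows only the head of replace_module; a helper like this typically finishes by rebinding the child attribute on its parent module. A minimal sketch of the full pattern, with everything after the else branch assumed rather than taken from the file:

    def replace_module(model, name, new_module):
        # Split a dotted path like "layers.0.mlp" into parent path and child name.
        if "." in name:
            parent_name = name.rsplit(".", 1)[0]
            child_name = name[len(parent_name) + 1:]
            parent = model.model.get_submodule(parent_name)
        else:
            # Top-level module: assumed fallback where the root acts as parent.
            parent_name = ""
            child_name = name
            parent = model.model
        # Assumed completion: rebind the attribute so the new module
        # takes the old one's place in the module tree.
        setattr(parent, child_name, new_module)
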
@@ -197,8 +198,11 @@ def quantize_activations(model, calibration_tokens):
     cleanup_memory()
 
     # Calibration.
-    for row_idx in range(calibration_tokens.shape[0]):
-        _ = model(calibration_tokens[row_idx].reshape(1, -1))
+    with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating") as pbar:
+        for row_idx in range(calibration_tokens.shape[0]):
+            model(calibration_tokens[row_idx].reshape(1, -1))
+            torch.cuda.empty_cache()
+            pbar.update(1)
 
     # Replace quantizer with StaticLayer.
     for name, quantizer in model.model.named_modules():
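
The leak fix itself is a small, reusable pattern: feed one calibration row through the model at a time and release PyTorch's cached, no-longer-used CUDA blocks after every forward pass, with tqdm reporting progress. A standalone sketch of that loop under stated assumptions (model is a callable module already on GPU, calibration_tokens is a 2-D tensor of token IDs, and the no_grad guard is an extra precaution not present in the commit):

    import torch
    import tqdm

    def calibrate(model, calibration_tokens):
        with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating") as pbar:
            for row_idx in range(calibration_tokens.shape[0]):
                # Forward pass only; the output is discarded, the point is
                # to let the activation quantizers observe real data.
                with torch.no_grad():  # assumption: not in the commit itself
                    model(calibration_tokens[row_idx].reshape(1, -1))
                # Release cached, unused CUDA memory so per-row allocations
                # do not accumulate across iterations.
                torch.cuda.empty_cache()
                pbar.update(1)

Calling torch.cuda.empty_cache() on every iteration costs some speed but keeps cached-allocator growth bounded, which is the trade the commit makes.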
