Skip to content

Commit

Permalink
Merge branch 'main' into support-kv-cache-scales
Browse files: browse the repository at this point in the history
  • Loading branch information
mgoin authored Jun 17, 2024
2 parents 142eac8 + b1c6ad6 commit bb15a62
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions auto_fp8/quantize.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -272,13 +272,12 @@ def quantize_activations(
cleanup_memory()

# Pass through calibration data to measure activation scales
with tqdm.tqdm(
total=calibration_tokens.shape[0], desc="Calibrating activation scales"
) as pbar:
for row_idx in range(calibration_tokens.shape[0]):
model(calibration_tokens[row_idx].reshape(1, -1))
cleanup_memory()
pbar.update(1)
with torch.inference_mode():
with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
for row_idx in range(calibration_tokens.shape[0]):
model(calibration_tokens[row_idx].reshape(1, -1))
cleanup_memory()
pbar.update(1)

# Replace dynamic quantizer observer with StaticLinear for export
for name, quantizer in model.named_modules():
Expand Down

0 comments on commit bb15a62

Please sign in to comment.