fix examples/perplexity (#1191)
Qubitium authored Jan 31, 2025
1 parent fb60ef3 · commit 275dabb
Showing 1 changed file with 11 additions and 21 deletions.
examples/benchmark/perplexity.py: 32 changes (11 additions & 21 deletions)
@@ -16,34 +16,27 @@
 import argparse
 import os
 
-import torch
 from gptqmodel.utils import Perplexity
 from transformers import AutoTokenizer
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
 
 if __name__ == "__main__":
     """
-    Example usage.
+    Example usage:
     Default usage with GPT2 model:
     python examples/benchmark/perplexity.py
     Specify GPTQ quantized model:
     python examples/benchmark/perplexity.py \
-        --model_name LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit \
-        --is_quantized \
-        --backend AUTO
+        --model ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v2.5 \
+        --is_quantized
     Change your dataset:
     python examples/benchmark/perplexity.py --dataset_path tiny_shakespeare
     """
     parser = argparse.ArgumentParser(description="Calculate Perplexity for a model.")
-    parser.add_argument("--model_name", type=str, default="gpt2", help="Model name.")
-    parser.add_argument("--model_basename", type=str, default=None, help="Model file's basename.")
-    parser.add_argument("--n_ctx", type=int, default=512, help="Context size.")
-    parser.add_argument("--n_batch", type=int, default=512, help="Batch size.")
+    parser.add_argument("--model", type=str, default="ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v2.5", help="Model name.")
+    parser.add_argument("--n_ctx", type=int, default=1024, help="Context size.")
+    parser.add_argument("--n_batch", type=int, default=1024, help="Batch size.")
     parser.add_argument("--dataset_path", type=str, default="wikitext", help="Path to the dataset.")
     parser.add_argument("--dataset_name", type=str, default=None, help="Name of the dataset.")
     parser.add_argument("--split", type=str, default="test", help="Dataset split to use.")
@@ -56,32 +49,29 @@
     parser.add_argument("--is_quantized", action="store_true", help="Is the model GPTQ quantized?")
     parser.add_argument("--use_fast_tokenizer", action="store_true", help="Whether to use fast tokenizer")
     parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
-    parser.add_argument("--backend", choices=['AUTO', 'TRITON', 'EXLLAMA_V2', 'MARLIN', 'CUDA', 'BITBLAS', 'IPEX'], help="Whether to use BACKEND format")
+    parser.add_argument("--backend", choices=['auto', 'marlin', 'exllama_v1', 'exllama_v2', 'triton', 'cuda', 'torch', 'ipex', 'bitblas'], default='auto', help="Whether to use BACKEND format")
     args = parser.parse_args()
 
     os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=args.use_fast_tokenizer)
+    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=args.use_fast_tokenizer)
     if not tokenizer.pad_token_id:
         tokenizer.pad_token_id = tokenizer.eos_token_id
 
     if args.is_quantized:
         from gptqmodel import BACKEND, GPTQModel
 
         model = GPTQModel.load(
-            args.model_name,
+            args.model,
             device_map="auto",
-            model_basename=args.model_basename,
             trust_remote_code=args.trust_remote_code,
             backend=BACKEND(args.backend.lower()),
         )
     else:
         from transformers import AutoModelForCausalLM
 
         model = AutoModelForCausalLM.from_pretrained(
-            args.model_name,
+            args.model,
             device_map="auto",
-            torch_dtype=torch.float16,
+            torch_dtype="auto",
             trust_remote_code=args.trust_remote_code,
         )
 
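Note on the backend change above: --backend now takes lowercase choices and defaults to 'auto', and GPTQModel.load resolves the string via BACKEND(args.backend.lower()). A minimal sketch of that resolution, assuming BACKEND is a str-valued Enum whose member values match the lowercase choices; the class below is a hypothetical stand-in, not gptqmodel's actual definition:

from enum import Enum

class BACKEND(str, Enum):
    # Hypothetical stand-in for gptqmodel's BACKEND enum; values are
    # assumed to match the script's lowercase --backend choices.
    AUTO = "auto"
    MARLIN = "marlin"
    EXLLAMA_V2 = "exllama_v2"
    TORCH = "torch"

# Calling an Enum with a value looks the member up by value,
# so the CLI string maps directly to a member:
assert BACKEND("marlin") is BACKEND.MARLIN
# .lower() keeps old-style uppercase input (e.g. "AUTO") working:
assert BACKEND("AUTO".lower()) is BACKEND.AUTO

With default='auto' in place, omitting --backend always hands a valid string to the enum lookup.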
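For reference, a quantized run with an explicit backend now follows the docstring's pattern (the marlin choice here is purely illustrative):

python examples/benchmark/perplexity.py \
    --model ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v2.5 \
    --is_quantized \
    --backend marlin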
