From d03d70b826579ca5aace15c2305cce9f66fe7ad8 Mon Sep 17 00:00:00 2001
From: Qubitium-ModelCloud
Date: Sun, 2 Feb 2025 10:22:41 +0800
Subject: [PATCH] skip zero math (#1199)

---
 gptqmodel/quantization/gptq.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index d00bc39e..87f50c6a 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -33,6 +33,7 @@
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cudnn.allow_tf32 = False
 
+CPU = torch.device("cpu")
 
 class GPTQ:
     def __init__(self, layer):
@@ -41,7 +42,7 @@ def __init__(self, layer):
         self.layer_copy = self._clone_layer()
 
         self.rows, self.columns = self.layer_copy.shape[0], self.layer_copy.shape[1]
-        self.H = torch.zeros((self.columns, self.columns), device=self.device)
+        # self.H = torch.zeros((self.columns, self.columns), device=self.device)
         self.nsamples = 0
         self.quantizer = Quantizer()
 
@@ -87,7 +88,11 @@ def add_batch(self, inp, out):
             inp = inp.permute([1, 0, 2])
             inp = inp.flatten(1)
 
-        self.H *= self.nsamples / (self.nsamples + tmp)
+        if not hasattr(self, "H"):
+            self.H = torch.zeros((self.columns, self.columns), device=self.device)
+        else:
+            self.H *= self.nsamples / (self.nsamples + tmp)
+
         self.nsamples += tmp
         # inp = inp.float()
         inp = math.sqrt(2 / self.nsamples) * inp.float()
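
For context, a minimal sketch of the idea behind this patch, outside the diff: on the first `add_batch` call the Hessian accumulator `H` is all zeros, so rescaling it by `nsamples / (nsamples + tmp)` is wasted work; allocating `H` lazily on first use skips that multiply. The standalone `HessianAccumulator` class, tensor shapes, and default `columns` below are illustrative assumptions, not part of the gptqmodel API.

```python
import math
import torch

class HessianAccumulator:
    """Illustrative sketch (not the gptqmodel class): lazily allocate H so the
    first batch skips the pointless rescale of an all-zero matrix."""

    def __init__(self, columns: int, device: torch.device = torch.device("cpu")):
        self.columns = columns
        self.device = device
        self.H = None          # allocated on the first add_batch call
        self.nsamples = 0

    def add_batch(self, inp: torch.Tensor) -> None:
        # inp: (batch, columns) calibration activations (assumed layout)
        tmp = inp.shape[0]
        if self.H is None:
            # First batch: start from zeros, no rescaling needed.
            self.H = torch.zeros((self.columns, self.columns), device=self.device)
        else:
            # Re-weight the running accumulator before folding in the new batch.
            self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = math.sqrt(2 / self.nsamples) * inp.float().t()  # (columns, batch)
        self.H += inp.matmul(inp.t())

# usage sketch
acc = HessianAccumulator(columns=16)
acc.add_batch(torch.randn(4, 16))
acc.add_batch(torch.randn(8, 16))
print(acc.H.shape, acc.nsamples)  # torch.Size([16, 16]) 12
```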