diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index d00bc39e..87f50c6a 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -33,6 +33,7 @@
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cudnn.allow_tf32 = False
 
+CPU = torch.device("cpu")
 
 class GPTQ:
     def __init__(self, layer):
@@ -41,7 +42,7 @@ def __init__(self, layer):
         self.layer_copy = self._clone_layer()
 
         self.rows, self.columns = self.layer_copy.shape[0], self.layer_copy.shape[1]
-        self.H = torch.zeros((self.columns, self.columns), device=self.device)
+        # self.H = torch.zeros((self.columns, self.columns), device=self.device)
         self.nsamples = 0
         self.quantizer = Quantizer()
 
@@ -87,7 +88,11 @@ def add_batch(self, inp, out):
             inp = inp.permute([1, 0, 2])
             inp = inp.flatten(1)
 
-        self.H *= self.nsamples / (self.nsamples + tmp)
+        if not hasattr(self, "H"):
+            self.H = torch.zeros((self.columns, self.columns), device=self.device)
+        else:
+            self.H *= self.nsamples / (self.nsamples + tmp)
+
         self.nsamples += tmp
         # inp = inp.float()
         inp = math.sqrt(2 / self.nsamples) * inp.float()
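
Note: the patch moves the (columns x columns) Hessian buffer from eager allocation in __init__ to lazy allocation on the first add_batch call, so layers that never see calibration data never pay for H. Below is a minimal standalone sketch of that pattern and of the running-average update it preserves (H ~= (2 / nsamples) * sum_i x_i x_i^T). The HessianAccumulator class, its shapes, and the None sentinel are illustrative assumptions, not part of the patch:

    import math
    import torch

    class HessianAccumulator:
        """Running average H ~= (2 / nsamples) * sum_i x_i x_i^T, allocated lazily."""

        def __init__(self, columns, device=torch.device("cpu")):
            self.columns = columns
            self.device = device
            self.nsamples = 0
            self.H = None  # (columns, columns) buffer deferred until data arrives

        def add_batch(self, inp):
            # inp: (batch, columns) calibration activations for one layer
            tmp = inp.shape[0]
            if self.H is None:
                # first batch: allocate the accumulator only now
                self.H = torch.zeros((self.columns, self.columns), device=self.device)
            else:
                # rescale the existing average so old and new samples weigh equally:
                # (2/n) * S_old  ->  (2/(n+tmp)) * S_old
                self.H *= self.nsamples / (self.nsamples + tmp)
            self.nsamples += tmp
            x = math.sqrt(2 / self.nsamples) * inp.float().t()  # (columns, batch)
            self.H += x.matmul(x.t())

    acc = HessianAccumulator(columns=4)
    acc.add_batch(torch.randn(8, 4))
    acc.add_batch(torch.randn(8, 4))

The patch itself tests existence with hasattr(self, "H") (since the __init__ assignment is commented out, the attribute simply does not exist before the first batch); the sketch uses an explicit None sentinel, which is equivalent here and keeps the attribute visible in __init__.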