Commit d03d70b
skip zero math (#1199)
Qubitium authored Feb 2, 2025
1 parent dbe31f9 commit d03d70b
Showing 1 changed file with 7 additions and 2 deletions.
gptqmodel/quantization/gptq.py: 9 changes (7 additions & 2 deletions)
@@ -33,6 +33,7 @@
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cudnn.allow_tf32 = False
 
+CPU = torch.device("cpu")
 
 class GPTQ:
     def __init__(self, layer):
@@ -41,7 +42,7 @@ def __init__(self, layer):
         self.layer_copy = self._clone_layer()
 
         self.rows, self.columns = self.layer_copy.shape[0], self.layer_copy.shape[1]
-        self.H = torch.zeros((self.columns, self.columns), device=self.device)
+        # self.H = torch.zeros((self.columns, self.columns), device=self.device)
         self.nsamples = 0
         self.quantizer = Quantizer()
 
@@ -87,7 +88,11 @@ def add_batch(self, inp, out):
             inp = inp.permute([1, 0, 2])
             inp = inp.flatten(1)
 
-        self.H *= self.nsamples / (self.nsamples + tmp)
+        if not hasattr(self, "H"):
+            self.H = torch.zeros((self.columns, self.columns), device=self.device)
+        else:
+            self.H *= self.nsamples / (self.nsamples + tmp)
+
         self.nsamples += tmp
         # inp = inp.float()
         inp = math.sqrt(2 / self.nsamples) * inp.float()
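
In effect, the change defers creating the Hessian accumulator H until the first add_batch call. The old code zero-initialized H in __init__ and then rescaled it on every batch, so the very first batch multiplied an all-zero matrix (the "zero math" the commit title refers to), and H was allocated even for layers that never received calibration data. Below is a minimal sketch of the pattern in isolation; the class, shapes, and usage are illustrative assumptions, not the actual GPTQ implementation.

# Minimal sketch of the lazy Hessian accumulation this commit switches to.
# The class below is a simplified stand-in, not the real GPTQ class; the
# column count, device, and input shapes are assumptions for illustration.
import math

import torch


class LazyHessianAccumulator:
    def __init__(self, columns, device=torch.device("cpu")):
        self.columns = columns
        self.device = device
        self.nsamples = 0
        # H is intentionally not allocated here; it is created on first use.

    def add_batch(self, inp):
        # inp: (batch, columns) activations for one calibration batch.
        tmp = inp.shape[0]
        if not hasattr(self, "H"):
            # First batch: allocate H directly. Rescaling an all-zero matrix
            # would be wasted work, so that step is skipped entirely.
            self.H = torch.zeros((self.columns, self.columns), device=self.device)
        else:
            # Later batches: downweight the running sum so H stays equal to
            # 2 / nsamples times the sum of x @ x.T over all seen samples.
            self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = math.sqrt(2 / self.nsamples) * inp.float().t()  # (columns, batch)
        self.H += inp.matmul(inp.t())


acc = LazyHessianAccumulator(columns=8)
acc.add_batch(torch.randn(4, 8))   # allocates H, no rescale of zeros
acc.add_batch(torch.randn(16, 8))  # rescales, then accumulates
print(acc.H.shape)                 # torch.Size([8, 8])

Using hasattr as the sentinel mirrors the diff; an explicit self.H = None in __init__ with an is None check would express the same intent more directly, but the observable behavior (no allocation and no rescale until the first batch) is the same.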
