reduce peak memory and reduce quant time (#1198)
Qubitium authored Feb 2, 2025
1 parent f095cb0 commit dbe31f9
Showing 1 changed file with 32 additions and 26 deletions.
gptqmodel/quantization/gptq.py: 32 additions & 26 deletions
@@ -30,7 +30,6 @@
 
 logger = setup_logger()
 
-# TODO do we really need max precision?
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cudnn.allow_tf32 = False
 
@@ -46,6 +45,12 @@ def __init__(self, layer):
         self.nsamples = 0
         self.quantizer = Quantizer()
 
+    def shape(self):
+        if hasattr(self, "layer"):
+            return self.layer.weight.shape
+        else:
+            return (0, 0)
+
     def _clone_layer(self):
         clone = self.layer.weight.data.clone()
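The new shape() helper pairs with the rewritten free() in the last hunk, which now deletes self.layer outright; a caller asking for dimensions afterwards would otherwise hit an AttributeError, and the hasattr guard turns that into a harmless (0, 0). A minimal self-contained sketch of the pattern (the Wrapper class is an illustration, not this repo's API):

from torch import nn

class Wrapper:
    def __init__(self, layer: nn.Linear):
        self.layer = layer

    def shape(self):
        # same guard as the commit: stay usable after free()
        if hasattr(self, "layer"):
            return self.layer.weight.shape
        return (0, 0)

    def free(self):
        del self.layer  # attribute vanishes entirely, unlike `self.layer = None`

w = Wrapper(nn.Linear(8, 4))
print(w.shape())   # torch.Size([4, 8])
w.free()
print(w.shape())   # (0, 0) instead of an AttributeError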

@@ -58,9 +63,9 @@ def _clone_layer(self):
         return clone.float()
 
     def add_batch(self, inp, out):
-        if os.environ.get("DEBUG"):
-            self.inp1 = inp
-            self.out1 = out
+        # if os.environ.get("DEBUG"):
+        #     self.inp1 = inp
+        #     self.out1 = out
 
         if len(inp.shape) == 2:
             inp = inp.unsqueeze(0)
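Under DEBUG, add_batch used to keep a live reference to every calibration activation (self.inp1/self.out1), pinning those tensors for the module's lifetime and inflating peak memory; only the running Hessian statistic is actually needed. A hedged sketch of that accumulation in the standard GPTQ style (simplified, not verbatim from this file):

import math
import torch

def add_batch(H, nsamples, inp):
    # standard GPTQ-style running Hessian update for a linear layer;
    # the real method also handles conv layers and other input layouts
    tmp = inp.shape[0]                                  # samples in this batch
    inp = inp.reshape(-1, inp.shape[-1]).t().float()    # (features, tokens)
    H = H * (nsamples / (nsamples + tmp))               # rescale running mean
    nsamples += tmp
    inp = math.sqrt(2 / nsamples) * inp
    return H + inp.matmul(inp.t()), nsamples            # no activations kept

H, n = torch.zeros(16, 16), 0
H, n = add_batch(H, n, torch.randn(2, 10, 16))          # one calibration batch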
@@ -238,12 +243,12 @@ def quantize(
 
             W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
 
-            if os.environ.get("DEBUG"):
-                self.layer.weight.data[:, :i2] = Q[:, :i2]
-                self.layer.weight.data[:, i2:] = W[:, i2:]
-
-                logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
-                logger.debug(torch.sum(Losses))
+            # if os.environ.get("DEBUG"):
+            #     self.layer.weight.data[:, :i2] = Q[:, :i2]
+            #     self.layer.weight.data[:, i2:] = W[:, i2:]
+            #
+            #     logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
+            #     logger.debug(torch.sum(Losses))
 
         torch_sync(self.device)
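The commented-out block wrote Q and W back into layer.weight.data and ran a full forward pass on every quantization block, which costs device copies and compute in the hot loop; dropping it is part of the "reduce quant time" claim. A hedged sketch of how the "reduce peak memory" half could be checked (run_quant is a hypothetical stand-in for a layer's quantize() call, not an API in this repo):

import torch

def measure_peak(run_quant, device="cuda"):
    # track the high-water mark of CUDA allocations around one call
    torch.cuda.reset_peak_memory_stats(device)
    run_quant()
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 2**20  # MiB

# usage (hypothetical): peak_mib = measure_peak(lambda: gptq.quantize(...))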

@@ -270,15 +275,15 @@ def quantize(
             Q = Q.t()
 
         if Q.shape != self.layer.weight.shape:
-            self.layer.weight.data = Q.cpu().reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
+            self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
         else:
-            self.layer.weight.data = Q.cpu().type_as(self.layer.weight.data)
+            self.layer.weight.data = Q.type_as(self.layer.weight.data)
 
         # move back to self.dev
         self.layer.weight.data = self.layer.weight.data.to(device=self.device)
 
-        if os.environ.get("DEBUG"):
-            logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
+        # if os.environ.get("DEBUG"):
+        #     logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
 
         if scale == []:
             scale.append(self.quantizer.scale)
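Dropping the .cpu() call keeps Q on its current device instead of bouncing it through host memory right before the subsequent .to(device=self.device) moves the weight back anyway. A hedged micro-benchmark sketch of what that round trip costs per layer (assumes a CUDA device; numbers vary with hardware and tensor size):

import torch

dev = "cuda"
Q = torch.randn(4096, 4096, device=dev)
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
_ = Q.cpu().to(dev)          # old path: round trip through host memory
end.record()
torch.cuda.synchronize()
print("round trip ms:", start.elapsed_time(end))

start.record()
_ = Q.to(dev)                # new path: no-op, tensor already resident
end.record()
torch.cuda.synchronize()
print("on-device ms:", start.elapsed_time(end))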
@@ -291,18 +296,19 @@ def quantize(
         return scale, zero, g_idx, duration, avg_loss, percdamp
 
     def free(self):
-        if os.environ.get("DEBUG"):
-            self.inp1 = None
-            self.out1 = None
-
-        self.H = None
-        self.Losses = None
-        self.Trace = None
-
-        self.quantizer = None
-        self.layer_copy = None
-
-        torch_empty_cache(self.device)
+        # if os.environ.get("DEBUG"):
+        #     self.inp1 = None
+        #     self.out1 = None
+        # del self.inp1
+        # del self.out1
+
+        if hasattr(self, "H"):
+            del self.H
+        del self.quantizer
+        del self.layer_copy
+        del self.layer
+
+        # torch_empty_cache(self.device)
 
 
 __all__ = ["GPTQ"]
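free() now uses del instead of assigning None: the attribute disappears from the instance and the tensor's last reference is dropped, so the CUDA caching allocator can reuse the block immediately. The hasattr guard keeps free() safe when quantize() never ran (H never created), and commenting out torch_empty_cache skips a slow per-layer release of cached blocks back to the driver. A small self-contained sketch of the pattern (class and names simplified, not this repo's API):

import gc
import torch

class Quant:
    def quantize(self):
        self.H = torch.zeros(1024, 1024)   # built lazily, may never exist

    def free(self):
        if hasattr(self, "H"):   # quantize() may never have run
            del self.H           # drop the last reference to the buffer

q = Quant()
q.free()                         # safe even before quantize()
q.quantize()
q.free()
gc.collect()
# With CUDA tensors, freed blocks return to PyTorch's caching allocator
# right away; torch.cuda.empty_cache() (cf. the commented-out
# torch_empty_cache helper) only hands them back to the driver, which is
# too slow to do once per layer.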
