From dbe31f955cc3cad56f309bf084e634c695ede4db Mon Sep 17 00:00:00 2001
From: Qubitium-ModelCloud
Date: Sun, 2 Feb 2025 09:49:02 +0800
Subject: [PATCH] reduce peak memory and reduce quant time (#1198)

---
 gptqmodel/quantization/gptq.py | 58 ++++++++++++++++++++++++++++++++--------------------------
 1 file changed, 32 insertions(+), 26 deletions(-)

diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index 4606f0dba..d00bc39e1 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -30,7 +30,6 @@
 
 logger = setup_logger()
 
-# TODO do we really need max precision?
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cudnn.allow_tf32 = False
 
@@ -46,6 +45,12 @@ def __init__(self, layer):
         self.nsamples = 0
         self.quantizer = Quantizer()
 
+    def shape(self):
+        if hasattr(self, "layer"):
+            return self.layer.weight.shape
+        else:
+            return (0, 0)
+
     def _clone_layer(self):
         clone = self.layer.weight.data.clone()
 
@@ -58,9 +63,9 @@ def _clone_layer(self):
         return clone.float()
 
     def add_batch(self, inp, out):
-        if os.environ.get("DEBUG"):
-            self.inp1 = inp
-            self.out1 = out
+        # if os.environ.get("DEBUG"):
+        #     self.inp1 = inp
+        #     self.out1 = out
 
         if len(inp.shape) == 2:
             inp = inp.unsqueeze(0)
@@ -238,12 +243,12 @@ def quantize(
 
             W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
 
-            if os.environ.get("DEBUG"):
-                self.layer.weight.data[:, :i2] = Q[:, :i2]
-                self.layer.weight.data[:, i2:] = W[:, i2:]
-
-                logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
-                logger.debug(torch.sum(Losses))
+            # if os.environ.get("DEBUG"):
+            #     self.layer.weight.data[:, :i2] = Q[:, :i2]
+            #     self.layer.weight.data[:, i2:] = W[:, i2:]
+            #
+            #     logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
+            #     logger.debug(torch.sum(Losses))
 
         torch_sync(self.device)
 
@@ -270,15 +275,15 @@ def quantize(
         Q = Q.t()
 
         if Q.shape != self.layer.weight.shape:
-            self.layer.weight.data = Q.cpu().reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
+            self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
         else:
-            self.layer.weight.data = Q.cpu().type_as(self.layer.weight.data)
+            self.layer.weight.data = Q.type_as(self.layer.weight.data)
 
         # move back to self.dev
         self.layer.weight.data = self.layer.weight.data.to(device=self.device)
 
-        if os.environ.get("DEBUG"):
-            logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
+        # if os.environ.get("DEBUG"):
+        #     logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
 
         if scale == []:
             scale.append(self.quantizer.scale)
@@ -291,18 +296,19 @@ def quantize(
         return scale, zero, g_idx, duration, avg_loss, percdamp
 
     def free(self):
-        if os.environ.get("DEBUG"):
-            self.inp1 = None
-            self.out1 = None
-
-        self.H = None
-        self.Losses = None
-        self.Trace = None
-
-        self.quantizer = None
-        self.layer_copy = None
-
-        torch_empty_cache(self.device)
+        # if os.environ.get("DEBUG"):
+        #     self.inp1 = None
+        #     self.out1 = None
+        #     del self.inp1
+        #     del self.out1
+
+        if hasattr(self, "H"):
+            del self.H
+        del self.quantizer
+        del self.layer_copy
+        del self.layer
+
+        # torch_empty_cache(self.device)
 
 
 __all__ = ["GPTQ"]
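
The free() rewrite above swaps "self.H = None"-style assignments for "del self.H", which removes the attribute and its tensor reference outright, so the CUDA caching allocator can reuse that memory as soon as the last reference dies; the hasattr() guard keeps a repeated free() from raising AttributeError. Below is a minimal, self-contained sketch of that pattern, not the real GPTQ class: the Worker class, its buffer names, and sizes are illustrative only.

    import torch

    class Worker:
        """Holds the large per-layer buffers that free() must release."""

        def __init__(self, rows: int = 1024):
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # Hessian-like accumulator: the dominant peak-memory contributor.
            self.H = torch.zeros(rows, rows, device=device)
            self.layer_copy = torch.randn(rows, rows, device=device)

        def free(self):
            # del removes the attribute itself, not just its value, so the
            # tensor becomes collectible once no other reference holds it.
            # hasattr() turns a second free() call into a no-op instead of
            # an AttributeError.
            if hasattr(self, "H"):
                del self.H
            if hasattr(self, "layer_copy"):
                del self.layer_copy
            # The patch leaves cache emptying commented out, presumably
            # because the caching allocator reuses freed blocks without it
            # and the call itself costs time:
            # torch.cuda.empty_cache()

    w = Worker()
    w.free()
    w.free()  # safe: the guards make a double free harmless

The same reasoning explains the new shape() helper in the patch: after free() deletes self.layer, a hasattr() check lets callers ask for the weight shape without tripping over the missing attribute.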
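
The quantize() hunks contribute the other saving by dropping the Q.cpu() round trip: the old code copied the quantized weight to the host for the dtype cast and then copied it back with .to(device=...), while the new code casts on the device where Q already lives. A standalone sketch of the difference, with illustrative tensor names and sizes:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"

    weight = torch.empty(512, 512, dtype=torch.float16, device=device)
    Q = torch.randn(512, 512, dtype=torch.float32, device=device)

    # Old path: copy to host, cast, copy back to the target device.
    w_old = Q.cpu().type_as(weight).to(device=device)

    # New path: one on-device dtype cast, no host round trip and no
    # synchronizing transfer.
    w_new = Q.type_as(weight)

    assert w_new.device == weight.device and w_new.dtype == weight.dtype
    assert torch.equal(w_old, w_new)  # same values, far less traffic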