From 5f221f35309e13eff204eb4791fbc4c8a3ef6deb Mon Sep 17 00:00:00 2001 From: Qubitium-ModelCloud Date: Sun, 2 Feb 2025 12:36:13 +0800 Subject: [PATCH] Rename some layers var/method to module (#1201) * fix bad var/method names: quant loop happens per layer but actual quant is per module * add logs --- gptqmodel/models/base.py | 10 ++++++++-- gptqmodel/models/loader.py | 18 +++++++++--------- gptqmodel/models/writer.py | 18 +++++++++--------- gptqmodel/utils/model.py | 34 ++++++++++++++++++---------------- tests/test_perplexity.py | 2 ++ 5 files changed, 46 insertions(+), 36 deletions(-) diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index d0fe7b23..97a76dd4 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -36,7 +36,7 @@ from ..utils.device import get_cpu_usage_memory, get_gpu_usage_memory from ..utils.importer import select_quant_linear from ..utils.logger import setup_logger -from ..utils.model import (MODALITY, check_to_quantized, find_layers, get_device, +from ..utils.model import (MODALITY, check_to_quantized, find_modules, get_device, get_module, get_module_by_name_prefix, get_moe_layer_modules, move_to, nested_move_to, normalize_tokenizer, pack_model) from ..utils.progress import ProgressBar @@ -608,7 +608,7 @@ def store_lm_head_input_hook(_, args, kwargs): move_to(layer, self.quantize_config.device) cur_layer_device = get_device(layer) - full = find_layers(layer, name=self.lm_head if is_lm_head else "") + full = find_modules(layer, name=self.lm_head if is_lm_head else "") modules = [[self.lm_head]] if is_lm_head else layer_modules for index, names in enumerate(modules): subset = {n: full[n] for n in names if n in full} @@ -619,6 +619,7 @@ def store_lm_head_input_hook(_, args, kwargs): sym = self.quantize_config.sym mse = self.quantize_config.mse + # dynamic overrides if self.quantize_config.dynamic is not None: layer_name = self.lm_head if is_lm_head else f"{self.layers_node}.{i}.{name}" @@ -661,6 +662,7 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): else: handle.append(subset[name].register_forward_hook(add_batch(name))) + logger.info(f"layer-{i}-{name}: Begin Forward() Pass") fwd_start = time.time() for j in range(num_batches): layer_input = [] @@ -723,6 +725,8 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): damp_percent = self.quantize_config.dynamic_get(layer_name, "damp_percent", damp_percent) static_groups = self.quantize_config.dynamic_get(layer_name, "static_groups", static_groups) + + logger.info(f"Quantizing module START: {name}, {gptq[name].shape()}") scale, zero, g_idx, duration, avg_loss, damp_percent = gptq[name].quantize( percdamp=damp_percent, group_size=group_size, @@ -762,7 +766,9 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): move_to(g_idx, CPU), ) gptq[name].free() + logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}") + logger.info(f"layer-{i}-{name}: Begin Forward() Pass 2 Post-Quant") for j in range(num_batches): layer_input = [] for k, layer_inp in enumerate(layer_inputs[j]): diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index 5c9d1b2d..27526e9f 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -38,7 +38,7 @@ from ..utils.logger import setup_logger from ..utils.marlin import (_validate_marlin_compatibility, _validate_marlin_device_support, prepare_model_for_marlin_load) -from ..utils.model import (auto_dtype, convert_gptq_v1_to_v2_format, find_layers, get_checkpoints, +from ..utils.model import 
(auto_dtype, convert_gptq_v1_to_v2_format, find_modules, get_checkpoints, get_moe_layer_modules, gptqmodel_post_init, load_checkpoint_in_model_then_tie_weights, make_quant, normalize_tokenizer, simple_dispatch_model, verify_model_hash, verify_sharded_model_hashes) @@ -430,25 +430,25 @@ def skip(*args, **kwargs): cls.layer_modules = get_moe_layer_modules(layer_modules=cls.layer_modules, num_experts=num_experts) - layers = find_layers(model) - ignore_layers = [cls.lm_head] + cls.base_modules + modules = find_modules(model) + ignore_modules = [cls.lm_head] + cls.base_modules - for name in list(layers.keys()): + for name in list(modules.keys()): # allow loading of quantized lm_head if qcfg.lm_head and name == cls.lm_head: continue - if any(name.startswith(ignore_layer) for ignore_layer in ignore_layers) or all( - not name.endswith(ignore_layer) for sublist in cls.layer_modules for ignore_layer in sublist + if any(name.startswith(ignore_module) for ignore_module in ignore_modules) or all( + not name.endswith(ignore_module) for sublist in cls.layer_modules for ignore_module in sublist ): - # log non-lm-head quantizerd layers only + # log non-lm-head quantizerd modules only if name is not cls.lm_head: logger.info(f"The layer {name} is not quantized.") - del layers[name] + del modules[name] preload_qlinear_kernel = make_quant( model, - layers, + modules, qcfg.bits, qcfg.group_size, backend=backend, diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index 38570950..b086ad7c 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -38,7 +38,7 @@ META_FIELD_URI, META_QUANTIZER_GPTQMODEL, META_VALUE_URI, MIN_VERSION_WITH_V2) from ..utils.backend import BACKEND from ..utils.logger import setup_logger -from ..utils.model import (convert_gptq_v2_to_v1_format, copy_py_files, find_layers, +from ..utils.model import (convert_gptq_v2_to_v1_format, copy_py_files, find_modules, get_model_files_size, get_moe_layer_modules, get_state_dict_for_save, load_checkpoint_in_model_then_tie_weights, make_quant) from ..utils.torch import torch_empty_cache @@ -354,25 +354,25 @@ def skip(*args, **kwargs): _ = get_moe_layer_modules(layer_modules=self.layer_modules, num_experts=num_experts) - layers = find_layers(model) - ignore_layers = [self.lm_head] + self.base_modules + modules = find_modules(model) + ignore_modules = [self.lm_head] + self.base_modules - for name in list(layers.keys()): + for name in list(modules.keys()): # allow loading of quantized lm_head if qcfg.lm_head and name == self.lm_head: continue - if any(name.startswith(ignore_layer) for ignore_layer in ignore_layers) or all( - not name.endswith(ignore_layer) for sublist in self.layer_modules for ignore_layer in sublist + if any(name.startswith(ignore_module) for ignore_module in ignore_modules) or all( + not name.endswith(ignore_module) for sublist in self.layer_modules for ignore_module in sublist ): - # log non-lm-head quantizerd layers only + # log non-lm-head quantizerd modules only if name is not self.lm_head: logger.info(f"The layer {name} is not quantized.") - del layers[name] + del modules[name] make_quant( model, - layers, + modules, qcfg.bits, qcfg.group_size, backend=BACKEND.AUTO, diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 14572042..f11026ca 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -102,7 +102,7 @@ def nested_move_to(v, device): return v -def find_layers(module, layers=None, name=""): +def find_modules(module, layers=None, name=""): if not layers: 
layers = SUPPORTS_MODULE_TYPES @@ -111,7 +111,7 @@ def find_layers(module, layers=None, name=""): return {name: module} res = {} for name1, child in module.named_children(): - res.update(find_layers(child, layers=layers, name=name + "." + name1 if name != "" else name1)) + res.update(find_modules(child, layers=layers, name=name + "." + name1 if name != "" else name1)) return res @@ -423,22 +423,24 @@ def convert_gptq_v2_to_v1_format( return model -def pack_layer(name, qlayers, quantizers, layers, QuantLinear, pbar): +def pack_module(name, qModules, quantizers, layers, pbar=None): # Limit pack() thread usage to avoid auto-parallizataion regression with tctl.threadpool_limits(limits=1): - pbar.set_description(f"Packing {name}") + if pbar: + pbar.set_description(f"Packing {name}") quantizers[name], scale, zero, g_idx = quantizers[name] - layer_device = qlayers[name].device - qlayers[name].to(CPU) + layer_device = qModules[name].device + qModules[name].to(CPU) layers[name], scale, zero, g_idx = ( layers[name].to(CPU), scale.to(CPU), zero.to(CPU), g_idx.to(CPU) if g_idx is not None else None, ) - qlayers[name].pack(layers[name], scale, zero, g_idx) - qlayers[name].to(layer_device) - pbar.progress() + qModules[name].pack(layers[name], scale, zero, g_idx) + qModules[name].to(layer_device) + if pbar: + pbar.progress() def pack_model( @@ -455,7 +457,7 @@ def pack_model( parallel_packing: bool = True, pack_dtype: torch.dtype = None, ): - QuantLinear = select_quant_linear( + quantLinear = select_quant_linear( bits=bits, dynamic=dynamic, group_size=group_size, @@ -471,8 +473,8 @@ def pack_model( logger.info("Packing model...") - layers = find_layers(model) - layers = {n: layers[n] for n in quantizers} + modules = find_modules(model) + modules = {n: modules[n] for n in quantizers} make_quant( model, quantizers, @@ -486,8 +488,8 @@ def pack_model( dynamic=dynamic, pack_dtype=pack_dtype, ) - qlayers = find_layers(model, [QuantLinear]) - names = list(qlayers.keys()) + qModules = find_modules(model, [quantLinear]) + names = list(qModules.keys()) if parallel_packing: max_workers = 2 @@ -497,13 +499,13 @@ def pack_model( with ThreadPoolExecutor(max_workers=max_workers) as executor: with ProgressBar(total=len(names)) as pbar: def wrapper(name): - pack_layer(name, qlayers, quantizers, layers, QuantLinear, pbar) + pack_module(name, qModules, quantizers, modules, pbar) for _ in executor.map(wrapper, names): pass logger.info("Model packed.") - return QuantLinear + return quantLinear def verify_model_hash(file_path: str, verify_hash: str): diff --git a/tests/test_perplexity.py b/tests/test_perplexity.py index 08e826e6..aeb24fae 100644 --- a/tests/test_perplexity.py +++ b/tests/test_perplexity.py @@ -27,6 +27,7 @@ from gptqmodel.quantization.config import FORMAT, QUANT_METHOD, AutoRoundQuantizeConfig, QuantizeConfig # noqa: E402 from gptqmodel.utils import Perplexity # noqa: E402 from gptqmodel.utils.rocm import IS_ROCM # noqa: E402 +from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 from parameterized import parameterized # noqa: E402 from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402 @@ -167,6 +168,7 @@ def test_quantized_perplexity(self, method: QUANT_METHOD, format: FORMAT, bits: ) del model + torch_empty_cache() model = GPTQModel.load( tmp_dir,
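
Editorial note on the rename, with an illustrative sketch (not part of the patch): the central point of this change is that the quantization loop iterates per transformer layer, but the units actually quantized and packed are the individual sub-modules inside each layer, which is what the renamed `find_modules` helper (formerly `find_layers`) collects. The snippet below is a minimal, self-contained re-implementation of the recursive walk that `find_modules` performs, filtering for `torch.nn.Linear` leaves for clarity; the real helper lives in gptqmodel/utils/model.py and defaults its filter to `SUPPORTS_MODULE_TYPES`. The function name `find_linear_modules` and the toy model are illustrative assumptions, not code from the patch.

    # Illustrative only: mirrors the recursive named_children() walk shown in the
    # find_modules diff above, restricted to torch.nn.Linear for the example.
    import torch.nn as nn

    def find_linear_modules(module, name=""):
        # Leaf hit: return a dict mapping the dotted name to the module itself.
        if isinstance(module, nn.Linear):
            return {name: module}
        found = {}
        for child_name, child in module.named_children():
            prefix = f"{name}.{child_name}" if name else child_name
            found.update(find_linear_modules(child, name=prefix))
        return found

    # Usage: list the dotted module names that would be candidates for
    # per-module quantization inside a (toy) layer.
    if __name__ == "__main__":
        toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Sequential(nn.Linear(8, 4)))
        print(list(find_linear_modules(toy).keys()))  # ['0', '2.0']

The dict-of-dotted-names shape is what lets the loader and writer drop non-quantized entries by name prefix/suffix, and what `pack_model` later intersects with the `quantizers` dict before handing each entry to `pack_module`.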