diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index d2057cca..d0fe7b23 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -899,7 +899,7 @@ def save( else: self.save_pretrained(save_dir, **kwargs) - def compile(self, backend="inductor", mode="reduce-overhead"): + def compile(self, backend="inductor", mode="max-autotune"): if not self.quantized: logger.warning("model is not quantized, skip compiling...") return self @@ -914,9 +914,12 @@ def compile(self, backend="inductor", mode="reduce-overhead"): try: self.model = torch.compile(self.model, fullgraph=True, backend=backend, mode=mode) - except Exception: - logger.info("Compiling model again with `fullgraph=False`") - self.model = torch.compile(self.model, fullgraph=False, backend=backend, mode=mode) + except Exception as e: + logger.info(f"Compiling model again with `fullgraph=False`; `full-graph=True` compile failed: {e}") + try: + self.model = torch.compile(self.model, fullgraph=False, backend=backend, mode=mode) + except Exception as e: + logger.info(f"Compiling model failed: running model in non-compiled mode. {e}") return self def serve(self, diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index 5b072f17..c6a2aed1 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import math import sys from typing import List, Optional, Tuple @@ -39,13 +39,24 @@ class BaseQuantLinear(nn.Module): SUPPORTS_DEVICES: List[DEVICE] = None SUPPORTS_PLATFORM: List[PLATFORM] = None - def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int, pack_dtype: t.dtype, *args, + def __init__(self, + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + in_features: int, + out_features: int, + bias: bool, + pack_dtype: t.dtype, + register_buffers: bool = False, + register_buffers_in_features: int = None, + register_buffers_out_features: int = None, **kwargs): super().__init__() - self.infeatures = infeatures - self.outfeatures = outfeatures - self.group_size = group_size if group_size != -1 else infeatures + self.in_features = in_features + self.out_features = out_features + self.group_size = group_size if group_size != -1 else in_features self.bits = bits self.desc_act = desc_act self.pack_dtype = pack_dtype @@ -73,19 +84,64 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat raise ValueError("Unsupported weight_dtype. 
Only int16 and int32 are supported.") self.pack_factor = self.pack_dtype_bits // self.bits - _, err = self._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, infeatures=infeatures,outfeatures=outfeatures, pack_dtype=pack_dtype) + _, err = self._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, in_features=in_features, out_features=out_features, pack_dtype=pack_dtype) if err: raise err + # most kernels share same buffers so they can share same register buffer code + if register_buffers: + # some kernels auto-pads in/out features + in_features = self.in_features if not register_buffers_in_features else register_buffers_in_features + out_features = self.out_features if not register_buffers_out_features else register_buffers_out_features + + self.register_buffer( + "qweight", + t.zeros((in_features // self.pack_dtype_bits * self.bits, out_features), dtype=self.pack_dtype), + ) + self.register_buffer( + "qzeros", + t.zeros( + ( + math.ceil(in_features / self.group_size), + out_features // self.pack_dtype_bits * self.bits, + ), + dtype=self.pack_dtype, + ), + ) + self.register_buffer( + "scales", + t.zeros( + (math.ceil(in_features / self.group_size), out_features), + dtype=t.float16, # Scales are always float16 + ), + ) + self.register_buffer( + "g_idx", + t.tensor([i // self.group_size for i in range(in_features)], dtype=t.int32), + ) + if bias: + self.register_buffer("bias", t.zeros(out_features, dtype=t.float16)) + else: + self.bias = None + @classmethod # custom quant linear class can override this and add custom checks - def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None, - outfeatures:int=None, pack_dtype:t.dtype=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[ + def validate( + cls, + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + in_features:int=None, + out_features:int=None, + pack_dtype:t.dtype=None, + dynamic:Optional[dict]=None, + device:Optional[DEVICE]=None, + trainable:Optional[bool]=None) -> Tuple[ bool, Optional[Exception]]: - validate, err = cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, - infeatures=infeatures, outfeatures=outfeatures, pack_dtype=pack_dtype, dynamic=dynamic, - device=device, trainable=trainable) - return validate, err + return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, + in_features=in_features, out_features=out_features, pack_dtype=pack_dtype, + dynamic=dynamic, device=device, trainable=trainable) @classmethod # internal method and should not be overriden @@ -121,8 +177,8 @@ def verify_supports_params(cls): raise ValueError(f"{cls.__name__}.{name} cannot be None or an empty list.") @classmethod - def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: bool=False, pack_dtype:t.dtype=None, dynamic:Optional[dict]=None, infeatures:int=None, - outfeatures:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[bool, Optional[Exception]]: + def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: bool=False, pack_dtype:t.dtype=None, dynamic:Optional[dict]=None, in_features:int=None, + out_features:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[bool, Optional[Exception]]: cls.verify_supports_params() if pack_dtype not in cls.SUPPORTS_PACK_DTYPES: @@ -193,20 +249,20 @@ def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: 
err = f"{cls} only supports `{cls.SUPPORTS_DESC_ACT}` bits: actual desc_act = `{desc_act}` for layer `{layer}`" return False, NotImplementedError(err) - if infeatures is not None: - validate = all(infeatures % in_fea == 0 for in_fea in cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY) + if in_features is not None: + validate = all(in_features % in_fea == 0 for in_fea in cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY) if not validate: - err = f"{cls}: `infeatures` must be divisible by {cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY}." + err = f"{cls}: `in_features` must be divisible by {cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY}." return False, NotImplementedError(err) - validate = infeatures % group_size == 0 or cls.SUPPORTS_AUTO_PADDING + validate = in_features % group_size == 0 or cls.SUPPORTS_AUTO_PADDING if not validate: - err = f"{cls}: `infeatures` must be divisible by `group_size: {group_size}`." + err = f"{cls}: `in_features` must be divisible by `group_size: {group_size}`." return False, NotImplementedError(err) - if outfeatures is not None: - validate = all(outfeatures % out_fea == 0 for out_fea in cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY) + if out_features is not None: + validate = all(out_features % out_fea == 0 for out_fea in cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY) if not validate: - err = f"{cls}: `outfeatures` must be divisible by {cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY}." + err = f"{cls}: `out_features` must be divisible by {cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY}." return False, NotImplementedError(err) return True, None diff --git a/gptqmodel/nn_modules/qlinear/bitblas.py b/gptqmodel/nn_modules/qlinear/bitblas.py index 02fca774..ac13db07 100644 --- a/gptqmodel/nn_modules/qlinear/bitblas.py +++ b/gptqmodel/nn_modules/qlinear/bitblas.py @@ -116,8 +116,8 @@ def __init__( group_size: int, desc_act: bool, sym: bool, - infeatures: int, - outfeatures: int, + in_features: int, + out_features: int, pack_dtype: torch.dtype, bias: bool, enable_tuning: bool = True, @@ -127,18 +127,28 @@ def __init__( layout: str = "nt", **kwargs, ): - super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, pack_dtype=pack_dtype, **kwargs) + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + register_buffers=False, + **kwargs) import_bitblas() - self._validate_parameters(group_size, infeatures, outfeatures) + self._validate_parameters(group_size, in_features, out_features) self.opt_features = opt_features self.target = BITBLAS_TARGET self._configure_bitblas_matmul( enable_tuning, fast_decoding, bias, propagate_b, layout, bits ) - self._initialize_buffers(infeatures, outfeatures, bias) + self._initialize_buffers(in_features, out_features, bias) self.reset_parameters() @classmethod @@ -148,12 +158,12 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: return cls._validate(**args) def _validate_parameters( - self, group_size: int, infeatures: int, outfeatures: int + self, group_size: int, in_features: int, out_features: int ): - if infeatures % group_size != 0: - raise ValueError("`infeatures` must be divisible by `group_size`.") + if in_features % group_size != 0: + raise ValueError("`in_features` must be divisible by `group_size`.") - def _initialize_buffers(self, infeatures: int, outfeatures: int, bias: bool): + def _initialize_buffers(self, in_features: int, out_features: int, bias: bool): self.register_buffer( 
"qweight", torch.zeros( @@ -164,7 +174,7 @@ def _initialize_buffers(self, infeatures: int, outfeatures: int, bias: bool): self.register_buffer( "scales", torch.zeros( - (outfeatures, infeatures // self.group_size), dtype=self.TORCH_DTYPE + (out_features, in_features // self.group_size), dtype=self.TORCH_DTYPE ), ) if self.zeros_mode == "quantized": @@ -172,20 +182,20 @@ def _initialize_buffers(self, infeatures: int, outfeatures: int, bias: bool): self.register_buffer( "zeros", torch.zeros( - (infeatures // self.group_size, outfeatures // storage_nbit * self.bits), dtype=self.TORCH_STORAGE_DTYPE + (in_features // self.group_size, out_features // storage_nbit * self.bits), dtype=self.TORCH_STORAGE_DTYPE ), ) else: self.register_buffer( "zeros", torch.zeros( - (outfeatures, infeatures // self.group_size), dtype=self.TORCH_DTYPE + (out_features, in_features // self.group_size), dtype=self.TORCH_DTYPE ), ) if bias: self.register_buffer( - "bias", torch.zeros((outfeatures), dtype=self.TORCH_DTYPE) + "bias", torch.zeros((out_features), dtype=self.TORCH_DTYPE) ) else: self.bias = None @@ -200,8 +210,8 @@ def _configure_bitblas_matmul( W_dtype = f"uint{bits}" matmul_config = MatmulConfig( M=self.opt_features, - N=self.outfeatures, - K=self.infeatures, + N=self.out_features, + K=self.in_features, A_dtype=bitblas_dtype, W_dtype=W_dtype, out_dtype=bitblas_dtype, @@ -341,7 +351,7 @@ def pack(self, linear, scales, zeros, g_idx=None): def repack_from_gptq(self, gptq_module): from bitblas.quantization.utils import general_compress - # qweight in gptq old quant linear stored with (outfeatures, infeatures), should be transposed. + # qweight in gptq old quant linear stored with (out_features, in_features), should be transposed. qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE) if self.bitblas_matmul.weight_transform is not None: qweight = self.bitblas_matmul.weight_transform(qweight.cpu()).cuda() diff --git a/gptqmodel/nn_modules/qlinear/dynamic_cuda.py b/gptqmodel/nn_modules/qlinear/dynamic_cuda.py index 04950deb..7901992a 100644 --- a/gptqmodel/nn_modules/qlinear/dynamic_cuda.py +++ b/gptqmodel/nn_modules/qlinear/dynamic_cuda.py @@ -56,9 +56,10 @@ def __init__( group_size: int, sym: bool, desc_act: bool, - infeatures: int, - outfeatures: int, + in_features: int, + out_features: int, bias: bool, + pack_dtype: torch.dtype, kernel_switch_threshold=128, **kwargs, ): @@ -66,8 +67,18 @@ def __init__( raise ValueError( f"Trying to use the cuda backend, but could not import the C++/CUDA dependencies with the following error: {gptqmodel_cuda_import_exception}" ) - super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, - outfeatures=outfeatures, bias=bias, **kwargs) + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + **kwargs) + + # assert in_features % 64 == 0 and out_features % 64 == 0 self.kernel_switch_threshold = kernel_switch_threshold @@ -75,11 +86,9 @@ def __init__( self.gptqmodel_cuda = gptqmodel_cuda_256 # fall back to cuda_64 - if infeatures % 256 != 0 or outfeatures % 256 != 0: + if in_features % 256 != 0 or out_features % 256 != 0: self.gptqmodel_cuda = gptqmodel_cuda_64 - assert infeatures % 64 == 0 and outfeatures % 64 == 0 - if self.bits == 4: self.qmatmul = self.gptqmodel_cuda.vecquant4matmul elif self.bits == 8: @@ -96,10 +105,10 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: 
return cls._validate(**args) def forward(self, x: torch.Tensor): - out_shape = x.shape[:-1] + (self.outfeatures,) + out_shape = x.shape[:-1] + (self.out_features,) x = x.reshape(-1, x.shape[-1]) - assert x.device.type == "cuda" + # assert x.device.type == "cuda" # switch to torch kernel when input shape is >= kernel_switch_threshold # cuda is only optimized for < kernel_switch_threshold and will run slower than torch otherwise @@ -108,7 +117,7 @@ def forward(self, x: torch.Tensor): f"Input shape `{x.shape[0]}` >= `{self.kernel_switch_threshold}` is not optimized for cuda kernel: dynamic switching to torch kernel.") return self._forward(x, x.dtype, out_shape) - out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32) + out = torch.zeros((x.shape[0], self.out_features), device=x.device, dtype=torch.float32) self.qmatmul( x.to(dtype=torch.float32), self.qweight, @@ -120,7 +129,7 @@ def forward(self, x: torch.Tensor): out = out.to(x.dtype).reshape(out_shape) if self.bias is not None: - out = out + self.bias + out.add_(self.bias) return out diff --git a/gptqmodel/nn_modules/qlinear/exllama.py b/gptqmodel/nn_modules/qlinear/exllama.py index de171b44..dc30d8a7 100644 --- a/gptqmodel/nn_modules/qlinear/exllama.py +++ b/gptqmodel/nn_modules/qlinear/exllama.py @@ -74,53 +74,35 @@ class ExllamaQuantLinear(PackableQuantLinear): """Linear layer implementation with per-group 4-bit quantization of the weights""" - def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int, pack_dtype: torch.dtype, bias: bool, **kwargs,): + def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, in_features: int, out_features: int, pack_dtype: torch.dtype, bias: bool, **kwargs, ): if exllama_import_exception is not None: raise ValueError( f"Trying to use the exllama backend, but could not import the C++/CUDA dependencies with the following error: {exllama_import_exception}" ) # backup original values - self.original_outfeatures = outfeatures - self.original_infeatures = infeatures + self.original_out_features = out_features + self.original_in_features = in_features # auto pad - group_size = group_size if group_size != -1 else infeatures - outfeatures = outfeatures + (-outfeatures % 32) - infeatures = infeatures + (-infeatures % group_size) - - super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, pack_dtype=pack_dtype, **kwargs) - - self.register_buffer( - "qweight", - torch.zeros((self.original_infeatures // self.pack_dtype_bits * self.bits, self.original_outfeatures), dtype=torch.int32), - ) - self.register_buffer( - "qzeros", - torch.zeros( - ( - math.ceil(self.original_infeatures / self.group_size), - self.original_outfeatures // self.pack_dtype_bits * self.bits, - ), - dtype=torch.int32, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - (math.ceil(self.original_infeatures / self.group_size), self.original_outfeatures), - dtype=torch.float16, - ), - ) - self.register_buffer( - "g_idx", - torch.tensor([i // self.group_size for i in range(self.original_infeatures)], dtype=torch.int32), - ) - - if bias: - self.register_buffer("bias", torch.zeros(self.original_outfeatures, dtype=torch.float16)) - else: - self.bias = None + group_size = group_size if group_size != -1 else in_features + out_features = out_features + (-out_features % 32) + in_features = in_features + (-in_features % group_size) + self.in_features_padding_size = in_features - 
self.original_in_features + self.in_features_padding_shape = (0, self.in_features_padding_size) + + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + register_buffers=True, + register_buffers_in_features=self.original_in_features, + register_buffers_out_features=self.original_out_features, + **kwargs) @classmethod def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: @@ -130,16 +112,16 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: def post_init(self): # resize due to padding after model weights have been loaded - if self.outfeatures != self.original_outfeatures or self.infeatures != self.original_infeatures: - self.qweight.resize_(self.infeatures // self.pack_dtype_bits * self.bits, self.outfeatures) + if self.out_features != self.original_out_features or self.in_features != self.original_in_features: + self.qweight.resize_(self.in_features // self.pack_dtype_bits * self.bits, self.out_features) self.qzeros.resize_( - math.ceil(self.infeatures / self.group_size), - self.outfeatures // self.pack_dtype_bits * self.bits + math.ceil(self.in_features / self.group_size), + self.out_features // self.pack_dtype_bits * self.bits ) - self.scales.resize_((math.ceil(self.infeatures / self.group_size), self.outfeatures),) - self.g_idx = torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device) + self.scales.resize_((math.ceil(self.in_features / self.group_size), self.out_features), ) + self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32, device=self.g_idx.device) if self.bias is not None: - self.bias.resize_(self.outfeatures) + self.bias.resize_(self.out_features) self.width = self.qweight.shape[1] @@ -163,9 +145,9 @@ def forward(self, x): x = x.half() # TODO: need to run checks to make sure there is no performance regression padding with F.pad - # if infeatures is padded, we need to pad the input as well - if x.size(-1) != self.infeatures: - x = F.pad(x, (0, self.infeatures - self.original_infeatures)) + # if in_features is padded, we need to pad the input as well + if x.size(-1) != self.in_features: + x = F.pad(x, self.in_features_padding_shape) out = ext_q4_matmul(x, self.q4, self.width) diff --git a/gptqmodel/nn_modules/qlinear/exllamav2.py b/gptqmodel/nn_modules/qlinear/exllamav2.py index e73b1ec3..f564b1cf 100644 --- a/gptqmodel/nn_modules/qlinear/exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/exllamav2.py @@ -138,8 +138,8 @@ class ExllamaV2QuantLinear(BaseQuantLinear): """Linear layer implementation with per-group 4-bit quantization of the weights""" - def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int, pack_dtype: torch.dtype, - bias: bool, **kwargs,): + def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, in_features: int, out_features: int, pack_dtype: torch.dtype, + bias: bool, **kwargs, ): if exllama_v2_import_exception is not None: raise ValueError( @@ -147,51 +147,33 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat ) # backup original values - self.original_outfeatures = outfeatures - self.original_infeatures = infeatures + self.original_out_features = out_features + self.original_in_features = in_features # auto pad - group_size = group_size if group_size != -1 else infeatures - outfeatures = outfeatures + 
(-outfeatures % 32) - infeatures = infeatures + (-infeatures % group_size) - - super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, pack_dtype=pack_dtype, **kwargs) + group_size = group_size if group_size != -1 else in_features + out_features = out_features + (-out_features % 32) + in_features = in_features + (-in_features % group_size) + self.in_features_padding_size = in_features - self.original_in_features + self.in_features_padding_shape = (0, self.in_features_padding_size) + + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + register_buffers=True, + register_buffers_in_features=self.original_in_features, + register_buffers_out_features=self.original_out_features, + **kwargs) self.q_handle = None self.q_tensors = None - # I need to register the tensors, otherwise, we won't be able to load them easily using transformers ... - self.register_buffer( - "qweight", - torch.zeros((self.original_infeatures // self.pack_dtype_bits * self.bits, self.original_outfeatures), dtype=torch.int32), - ) - self.register_buffer( - "qzeros", - torch.zeros( - ( - math.ceil(self.original_infeatures / self.group_size), - self.original_outfeatures // self.pack_dtype_bits * self.bits, - ), - dtype=torch.int32, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - (math.ceil(self.original_infeatures / self.group_size), self.original_outfeatures), - dtype=torch.float16, - ), - ) - self.register_buffer( - "g_idx", - torch.tensor([i // self.group_size for i in range(self.original_infeatures)], dtype=torch.int32), - ) - - if bias: - self.register_buffer("bias", torch.zeros((self.original_outfeatures), dtype=torch.float16)) - else: - self.bias = None - @classmethod def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: if exllama_v2_import_exception is not None: @@ -200,16 +182,16 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: return cls._validate(**args) def post_init(self, temp_dq): # resize due to padding after model weights have been loaded - if self.outfeatures != self.original_outfeatures or self.infeatures != self.original_infeatures: - self.qweight.resize_(self.infeatures // self.pack_dtype_bits * self.bits, self.outfeatures) + if self.out_features != self.original_out_features or self.in_features != self.original_in_features: + self.qweight.resize_(self.in_features // self.pack_dtype_bits * self.bits, self.out_features) self.qzeros.resize_( - math.ceil(self.infeatures / self.group_size), - self.outfeatures // self.pack_dtype_bits * self.bits + math.ceil(self.in_features / self.group_size), + self.out_features // self.pack_dtype_bits * self.bits ) - self.scales.resize_(math.ceil(self.infeatures / self.group_size), self.outfeatures) - self.g_idx = torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device) + self.scales.resize_(math.ceil(self.in_features / self.group_size), self.out_features) + self.g_idx = torch.tensor([i // self.group_size for i in range(self.in_features)], dtype=torch.int32, device=self.g_idx.device) if self.bias is not None: - self.bias.resize_(self.outfeatures) + self.bias.resize_(self.out_features) self.q_tensors = { "qweight": self.qweight, @@ -229,11 +211,11 @@ def forward(self, x, force_cuda=False): x = x.half() # TODO: need to run checks to make sure there is no performance regression padding with F.pad - # if infeatures 
is padded, we need to pad the input as well - if x.size(-1) != self.infeatures: - x = F.pad(x, (0, self.infeatures - self.original_infeatures)) + # if in_features is padded, we need to pad the input as well + if x.size(-1) != self.in_features: + x = F.pad(x, self.in_features_padding_shape) - output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda) + output = ext_gemm_half_q_half(x, self.q_handle, self.out_features, force_cuda) if self.bias is not None: output.add_(self.bias) @@ -241,10 +223,10 @@ def forward(self, x, force_cuda=False): return output def temp_dq_size(self): - return self.infeatures * self.outfeatures * 2 + 128 + return self.in_features * self.out_features * 2 + 128 def temp_fwd_size(self, max_input_len, max_batch_size): - return self.outfeatures * max_input_len * max_batch_size * 4 + 128 + return self.out_features * max_input_len * max_batch_size * 4 + 128 def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8): return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size) diff --git a/gptqmodel/nn_modules/qlinear/ipex.py b/gptqmodel/nn_modules/qlinear/ipex.py index a6aacd30..cb1120c4 100644 --- a/gptqmodel/nn_modules/qlinear/ipex.py +++ b/gptqmodel/nn_modules/qlinear/ipex.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import Optional, Tuple import numpy as np @@ -111,50 +110,30 @@ def __init__( group_size: int, desc_act: bool, sym: bool, - infeatures: int, - outfeatures: int, + in_features: int, + out_features: int, pack_dtype: torch.dtype, bias: bool, kernel_switch_threshold=128, training=False, **kwargs, ): - super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, pack_dtype=pack_dtype, **kwargs) + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + register_buffers=True, + **kwargs) # FIX ME IPEX CPU has no float16 support self.weight_dtype = torch.float16 if HAS_XPU else torch.bfloat16 self.init_ipex = False - self.register_buffer( - "qweight", - torch.zeros((infeatures // self.pack_dtype_bits * self.bits, outfeatures), dtype=torch.int32), - ) - self.register_buffer( - "qzeros", - torch.zeros( - ( - math.ceil(infeatures / self.group_size), - outfeatures // self.pack_dtype_bits * self.bits, - ), - dtype=torch.int32, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - (math.ceil(infeatures / self.group_size), outfeatures), - dtype=self.weight_dtype, - ), - ) - self.register_buffer( - "g_idx", - torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32), - ) - if bias: - self.register_buffer("bias", torch.zeros((outfeatures), dtype=self.weight_dtype)) - else: - self.bias = None - self.kernel_switch_threshold = kernel_switch_threshold self.training = training @@ -174,8 +153,8 @@ def post_init(self): def init_ipex_linear(self, x: torch.Tensor): if not self.training and HAS_IPEX and not x.requires_grad: self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros, - self.infeatures, self.outfeatures, None, self.bias, - self.group_size, self.g_idx, quant_method=QuantMethod.GPTQ_GEMM, dtype=QuantDtype.INT4) + self.in_features, self.out_features, None, self.bias, + self.group_size, self.g_idx, quant_method=QuantMethod.GPTQ_GEMM, dtype=QuantDtype.INT4) def 
pack(self, linear, scales, zeros, g_idx=None): W = linear.weight.data.clone() @@ -229,7 +208,7 @@ def forward(self, x: torch.Tensor): if self.wf.device != x.device: self.wf = self.wf.to(x.device) - out_shape = x.shape[:-1] + (self.outfeatures,) + out_shape = x.shape[:-1] + (self.out_features,) x = x.reshape(-1, x.shape[-1]) x_dtype = x.dtype zeros = torch.bitwise_right_shift( @@ -264,7 +243,8 @@ def forward(self, x: torch.Tensor): out = torch.matmul(x, weights) out = out.to(x_dtype) out = out.reshape(out_shape) - out = out + self.bias if self.bias is not None else out + if self.bias is not None: + out.add_(self.bias) return out diff --git a/gptqmodel/nn_modules/qlinear/marlin.py b/gptqmodel/nn_modules/qlinear/marlin.py index 61f617d2..27abcff1 100644 --- a/gptqmodel/nn_modules/qlinear/marlin.py +++ b/gptqmodel/nn_modules/qlinear/marlin.py @@ -173,22 +173,32 @@ class MarlinQuantLinear(BaseQuantLinear): # for transformers/optimum tests compat QUANT_TYPE = "marlin" - def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int, pack_dtype: torch.dtype, + def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, in_features: int, out_features: int, pack_dtype: torch.dtype, bias: bool, **kwargs): if marlin_import_exception is not None: raise ValueError( f"Trying to use the marlin backend, but could not import the C++/CUDA dependencies with the following error: {marlin_import_exception}" ) - self.original_infeatures = infeatures - self.original_outfeatures = outfeatures + self.original_in_features = in_features + self.original_out_features = out_features if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) desc_act = False - super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, pack_dtype=pack_dtype, **kwargs) + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + register_buffers=False, + **kwargs) # Determine sharding if marlin_repeat_scales_on_all_ranks(desc_act, @@ -197,18 +207,18 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat # By setting scale_dim == None, weight_loader will # repeat the scales on each GPU in TP>1 case. scales_and_zp_input_dim = None - scales_and_zp_size = self.infeatures // self.group_size + scales_and_zp_size = self.in_features // self.group_size else: # By setting scale_dim == 0, weight_loader will # shard the scales in TP>1 case. 
scales_and_zp_input_dim = 0 - scales_and_zp_size = self.infeatures // self.group_size + scales_and_zp_size = self.in_features // self.group_size # Quantized weights qweight = Parameter( torch.empty( - self.infeatures // self.pack_factor, - self.outfeatures, + self.in_features // self.pack_factor, + self.out_features, dtype=torch.int32, ), requires_grad=False, @@ -226,7 +236,7 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat # Activation order g_idx = Parameter( torch.empty( - self.infeatures, + self.in_features, dtype=torch.int32, ), requires_grad=False, @@ -244,7 +254,7 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat scales = Parameter( torch.empty( scales_and_zp_size, - self.outfeatures, + self.out_features, dtype=torch.float16, ), requires_grad=False, @@ -261,7 +271,7 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat qzeros = Parameter( torch.empty( scales_and_zp_size, - self.outfeatures // self.pack_factor, + self.out_features // self.pack_factor, dtype=torch.int32, ), requires_grad=False, @@ -284,7 +294,7 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat self.is_k_full = marlin_is_k_full(self.desc_act, is_row_parallel=False) if bias: - self.register_buffer("bias", torch.zeros((self.outfeatures), dtype=torch.half)) + self.register_buffer("bias", torch.zeros((self.out_features), dtype=torch.float16)) else: self.bias = None @@ -317,7 +327,7 @@ def post_init(self): device = self.qweight.device # Allocate marlin workspace self.workspace = marlin_make_workspace( - self.outfeatures, device) + self.out_features, device) # Handle sorting for activation reordering if needed. if self.desc_act: @@ -335,8 +345,8 @@ def post_init(self): marlin_qweight = gptqmodel_marlin_kernels.gptq_marlin_repack( self.qweight, self.g_idx_sort_indices, - self.infeatures, - self.outfeatures, + self.in_features, + self.out_features, self.bits, self.pack_dtype_bits) replace_tensor(self, "qweight", marlin_qweight) @@ -344,14 +354,14 @@ def post_init(self): # Permute scales from autogptq format to marlin format. 
marlin_scales = marlin_permute_scales( self.scales, - size_k=self.infeatures, - size_n=self.outfeatures, + size_k=self.in_features, + size_n=self.out_features, group_size=self.group_size) replace_tensor(self, "scales", marlin_scales) def forward(self, A: torch.Tensor): if A.dtype != torch.float16: - A = A.half() + A = A.to(torch.float16) return apply_gptq_marlin_linear( input=A.contiguous() if self.is_lm_head else A, @@ -362,8 +372,8 @@ def forward(self, A: torch.Tensor): g_idx_sort_indices=self.g_idx_sort_indices, workspace=self.workspace, num_bits=self.bits, - output_size_per_partition=self.outfeatures, - input_size_per_partition=self.infeatures, + output_size_per_partition=self.out_features, + input_size_per_partition=self.in_features, is_k_full=self.is_k_full, bias=self.bias) diff --git a/gptqmodel/nn_modules/qlinear/torch.py b/gptqmodel/nn_modules/qlinear/torch.py index 3be53233..85a64d85 100644 --- a/gptqmodel/nn_modules/qlinear/torch.py +++ b/gptqmodel/nn_modules/qlinear/torch.py @@ -49,49 +49,29 @@ def __init__( group_size: int, sym: bool, desc_act: bool, - infeatures: int, - outfeatures: int, + in_features: int, + out_features: int, bias: bool, pack_dtype: torch.dtype, **kwargs, ): - super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, pack_dtype=pack_dtype, **kwargs) - - if self.group_size != self.infeatures: - self.padded_infeatures = self.infeatures + (-self.infeatures % self.group_size) + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + register_buffers=True, + **kwargs) + + if self.group_size != self.in_features: + self.padded_infeatures = self.in_features + (-self.in_features % self.group_size) else: self.padded_infeatures = self.padded_infeatures - self.register_buffer( - "qweight", - torch.zeros((self.infeatures // self.pack_dtype_bits * self.bits, self.outfeatures), dtype=self.pack_dtype), - ) - self.register_buffer( - "qzeros", - torch.zeros( - ( - math.ceil(self.infeatures / self.group_size), - self.outfeatures // self.pack_dtype_bits * self.bits, - ), - dtype=self.pack_dtype, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - (math.ceil(self.infeatures / self.group_size), self.outfeatures), - dtype=torch.float16, # Scales are always float16 - ), - ) - self.register_buffer( - "g_idx", - torch.tensor([i // self.group_size for i in range(self.infeatures)], dtype=torch.int32), - ) - if bias: - self.register_buffer("bias", torch.zeros((self.outfeatures), dtype=torch.float16)) - else: - self.bias = None - if self.bits in [2, 4, 8]: self.wf = torch.tensor(list(range(0, self.pack_dtype_bits, self.bits)), dtype=torch.int32).unsqueeze(0) elif self.bits == 3: @@ -105,13 +85,13 @@ def __init__( ).reshape(1, 3, 12) def post_init(self): - if self.padded_infeatures != self.infeatures: - self.qweight.resize_(self.padded_infeatures // self.pack_dtype_bits * self.bits, self.outfeatures) + if self.padded_infeatures != self.in_features: + self.qweight.resize_(self.padded_infeatures // self.pack_dtype_bits * self.bits, self.out_features) self.qzeros.resize_( math.ceil(self.padded_infeatures / self.group_size), - self.outfeatures // self.pack_dtype_bits * self.bits + self.out_features // self.pack_dtype_bits * self.bits ) - self.scales.resize_((math.ceil(self.padded_infeatures / self.group_size), self.outfeatures), ) + self.scales.resize_((math.ceil(self.padded_infeatures / 
self.group_size), self.out_features), ) self.g_idx = torch.tensor([i // self.group_size for i in range(self.padded_infeatures)], dtype=torch.int32, device=self.g_idx.device) @@ -119,9 +99,9 @@ def post_init(self): def forward(self, x: torch.Tensor): if x.size(-1) != self.padded_infeatures: - x = F.pad(x, (0, self.padded_infeatures - self.infeatures)) + x = F.pad(x, (0, self.padded_infeatures - self.in_features)) - out_shape = x.shape[:-1] + (self.outfeatures,) + out_shape = x.shape[:-1] + (self.out_features,) x = x.reshape(-1, x.shape[-1]) out = self._forward(x, x.dtype, out_shape) return out @@ -132,7 +112,7 @@ def _forward(self, x, x_dtype, out_shape): out = torch.matmul(x, weights).reshape(out_shape).to(x_dtype) if self.bias is not None: - out = out + self.bias + out.add_(self.bias) return out # clear gptq only weights: useful in de-quantization @@ -209,7 +189,7 @@ def dequantize_model(model: nn.Module): if isinstance(module, TorchQuantLinear): # Create a new Linear layer with dequantized weights - new_module = nn.Linear(module.infeatures, module.outfeatures) + new_module = nn.Linear(module.in_features, module.out_features) new_module.weight = nn.Parameter(module.dequantize_weight().T.detach().to("cpu", torch.float16)) new_module.bias = module.bias diff --git a/gptqmodel/nn_modules/qlinear/tritonv2.py b/gptqmodel/nn_modules/qlinear/tritonv2.py index a50c8892..43c39ba5 100644 --- a/gptqmodel/nn_modules/qlinear/tritonv2.py +++ b/gptqmodel/nn_modules/qlinear/tritonv2.py @@ -71,46 +71,26 @@ class TritonV2QuantLinear(PackableQuantLinear, TritonModuleMixin): dequant and matmul into single kernel.add() """ - def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures, outfeatures, pack_dtype, bias, **kwargs,): + def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, in_features, out_features, pack_dtype, bias, **kwargs, ): if not TRITON_AVAILABLE: raise ValueError(TRITON_INSTALL_HINT) - super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, pack_dtype=pack_dtype, **kwargs) - - if self.group_size != self.infeatures: - self.padded_infeatures = self.infeatures + (-self.infeatures % self.group_size) + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + register_buffers=True, + **kwargs) + + if self.group_size != self.in_features: + self.padded_infeatures = self.in_features + (-self.in_features % self.group_size) else: self.padded_infeatures = self.padded_infeatures - self.register_buffer( - "qweight", - torch.zeros((infeatures // self.pack_factor, outfeatures), dtype=self.pack_dtype), - ) - self.register_buffer( - "qzeros", - torch.zeros( - ( - math.ceil(infeatures / self.group_size), - outfeatures // self.pack_factor, - ), - dtype=self.pack_dtype, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - (math.ceil(infeatures / self.group_size), outfeatures), - dtype=torch.float16, - ), - ) - self.register_buffer( - "g_idx", - torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32), - ) - if bias: - self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16)) - else: - self.bias = None - @classmethod def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: if not TRITON_AVAILABLE: @@ -124,22 +104,22 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: return cls._validate(**args) def 
post_init(self): - if self.padded_infeatures != self.infeatures: - self.qweight.resize_(self.padded_infeatures // self.pack_factor, self.outfeatures) + if self.padded_infeatures != self.in_features: + self.qweight.resize_(self.padded_infeatures // self.pack_factor, self.out_features) self.qzeros.resize_( math.ceil(self.padded_infeatures / self.group_size), - self.outfeatures // self.pack_factor + self.out_features // self.pack_factor ) - self.scales.resize_((math.ceil(self.padded_infeatures / self.group_size), self.outfeatures), ) + self.scales.resize_((math.ceil(self.padded_infeatures / self.group_size), self.out_features), ) self.g_idx = torch.tensor([i // self.group_size for i in range(self.padded_infeatures)], dtype=torch.int32, device=self.g_idx.device) def forward(self, x): - # if infeatures is padded, we need to pad the input as well + # if in_features is padded, we need to pad the input as well if x.size(-1) != self.padded_infeatures: - x = F.pad(x, (0, self.padded_infeatures - self.infeatures)) + x = F.pad(x, (0, self.padded_infeatures - self.in_features)) - out_shape = x.shape[:-1] + (self.outfeatures,) + out_shape = x.shape[:-1] + (self.out_features,) out = QuantLinearFunction.apply( x.reshape(-1, x.shape[-1]), @@ -153,7 +133,7 @@ def forward(self, x): ) out = out.to(dtype=x.dtype).reshape(out_shape) if self.bias is not None: - out = out + self.bias + out.add_(self.bias) return out diff --git a/gptqmodel/nn_modules/triton_utils/dequant.py b/gptqmodel/nn_modules/triton_utils/dequant.py index 0e2e2e25..ec4fabc6 100644 --- a/gptqmodel/nn_modules/triton_utils/dequant.py +++ b/gptqmodel/nn_modules/triton_utils/dequant.py @@ -44,7 +44,7 @@ def dequant_kernel( pack_bits: tl.constexpr, maxq: tl.constexpr, bits: tl.constexpr, - outfeatures: tl.constexpr, + out_features: tl.constexpr, num_groups: tl.constexpr, X_BLOCK: tl.constexpr, ): @@ -52,15 +52,15 @@ def dequant_kernel( xoffset = tl.program_id(0) * X_BLOCK x_index = xoffset + tl.arange(0, X_BLOCK) xmask = x_index < numels - row_idx = x_index // outfeatures - col_idx = x_index % outfeatures + row_idx = x_index // out_features + col_idx = x_index % out_features elements_per_feature: tl.constexpr = pack_bits // bits # Load parameters g_idx = tl.load(g_idx_ptr + (row_idx), None, eviction_policy="evict_last") qweights = tl.load( - qweight_ptr + (col_idx + (outfeatures * (row_idx // elements_per_feature))), + qweight_ptr + (col_idx + (out_features * (row_idx // elements_per_feature))), None, ) @@ -72,13 +72,13 @@ def dequant_kernel( # tl.device_assert(g_idx >= 0, "index out of bounds: 0 <= tmp0 < 0") groups = tl.where(tmp2, tmp1, g_idx) # tmp3 are g_idx - scales = tl.load(scales_ptr + (col_idx + (outfeatures * groups)), None).to(tl.float32) + scales = tl.load(scales_ptr + (col_idx + (out_features * groups)), None).to(tl.float32) # Unpack weights weights = (qweights >> wf_weights) & maxq # bit shift qweight # Unpack zeros - qzero_ncols: tl.constexpr = outfeatures // elements_per_feature + qzero_ncols: tl.constexpr = out_features // elements_per_feature qzeros = tl.load( qzeros_ptr + ((qzero_ncols * groups) + (col_idx // elements_per_feature)), None, @@ -98,10 +98,10 @@ def dequant(qweight, scales, qzeros, g_idx, bits, pack_bits, maxq): """ num_groups = scales.shape[0] - outfeatures = scales.shape[1] - infeatures = g_idx.shape[0] + out_features = scales.shape[1] + in_features = g_idx.shape[0] - out = torch.empty((infeatures, outfeatures), device=qweight.device, dtype=torch.float16) + out = torch.empty((in_features, out_features), 
device=qweight.device, dtype=torch.float16) numels = out.numel() grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),) # noqa: E731 @@ -115,7 +115,7 @@ def dequant(qweight, scales, qzeros, g_idx, bits, pack_bits, maxq): pack_bits=pack_bits, maxq=maxq, bits=bits, - outfeatures=outfeatures, + out_features=out_features, num_groups=num_groups, ) return out diff --git a/gptqmodel/utils/bitblas.py b/gptqmodel/utils/bitblas.py index b80096f6..fb7551ff 100644 --- a/gptqmodel/utils/bitblas.py +++ b/gptqmodel/utils/bitblas.py @@ -106,8 +106,8 @@ def convert_to_bitblas(model, model_quantlinear, qcfg: QuantizeConfig, sym: bool group_size=qcfg.group_size, sym=sym, desc_act=desc_act, - infeatures=module.infeatures, - outfeatures=module.outfeatures, + in_features=module.in_features, + out_features=module.out_features, pack_dtype=qcfg.pack_dtype, bias=module.bias is not None, enable_tuning=True diff --git a/gptqmodel/utils/marlin.py b/gptqmodel/utils/marlin.py index 3093de8d..13e9c9aa 100644 --- a/gptqmodel/utils/marlin.py +++ b/gptqmodel/utils/marlin.py @@ -124,22 +124,22 @@ def convert_to_marlin( group_size=module.group_size, sym=sym, desc_act=desc_act, - infeatures=module.original_infeatures, - outfeatures=module.original_outfeatures, + in_features=module.original_in_features, + out_features=module.original_out_features, pack_dtype=module.pack_dtype, bias=module.bias is not None, ) # workspace is never in the state_dict, thus we need to allocate it manually. - new_module.workspace = torch.zeros(new_module.outfeatures // 128 * 16, dtype=module.pack_dtype, device=module.device) + new_module.workspace = torch.zeros(new_module.out_features // 128 * 16, dtype=module.pack_dtype, device=module.device) # Dequantize the weight. if repack: import gptqmodel_marlin_cuda qweight = module.qweight - if new_module.infeatures != new_module.original_infeatures or new_module.outfeatures != new_module.original_outfeatures: - padded_qweight = torch.zeros((new_module.infeatures, new_module.outfeatures), dtype=torch.int, device=module.qweight.device) + if new_module.in_features != new_module.original_in_features or new_module.out_features != new_module.original_out_features: + padded_qweight = torch.zeros((new_module.in_features, new_module.out_features), dtype=torch.int, device=module.qweight.device) padded_qweight[:module.qweight.size(0), :module.qweight.size(1)] = qweight qweight = padded_qweight @@ -158,17 +158,17 @@ def convert_to_marlin( s = module.scales.data.clone() - if new_module.infeatures != new_module.original_infeatures or new_module.outfeatures != new_module.original_outfeatures: - padded_s = torch.zeros((s.size(0), new_module.outfeatures), dtype=torch.half, device=s.device) + if new_module.in_features != new_module.original_in_features or new_module.out_features != new_module.original_out_features: + padded_s = torch.zeros((s.size(0), new_module.out_features), dtype=torch.half, device=s.device) padded_s[:s.size(0), :s.size(1)] = s s = padded_s - if module.group_size != module.infeatures: + if module.group_size != module.in_features: s = s.reshape((1, -1)) s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm] else: s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] - s = s.reshape((-1, new_module.outfeatures)).contiguous() + s = s.reshape((-1, new_module.out_features)).contiguous() new_module.B = marlin_repacked_weight new_module.s = s diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index fb5d6e7d..14572042 100644 --- a/gptqmodel/utils/model.py +++ 
b/gptqmodel/utils/model.py @@ -180,9 +180,18 @@ def make_quant( # else: # logger.info("make_quant: Testing linear: {linear}") - linear_instance = create_quant_layer(linear=linear, bits=bits, desc_act=desc_act, dynamic=dynamic, group_size=group_size, - module=module, names=names, sym=sym, device=device, lm_head_name=lm_head_name, - pack_dtype=pack_dtype) + linear_instance = create_quant_layer( + linear=linear, + bits=bits, + desc_act=desc_act, + dynamic=dynamic, + group_size=group_size, + module=module, + names=names, + sym=sym, + device=device, + lm_head_name=lm_head_name, + pack_dtype=pack_dtype) logger.info(f"make_quant: Selected linear: `{linear}`.") return linear_instance except NotImplementedError as e: @@ -194,8 +203,18 @@ def make_quant( raise ValueError("No compatible quant linear was found for this module: {module}") -def create_quant_layer(linear: nn.Module, bits: int, desc_act: bool, dynamic, group_size: int, module, names, sym: bool, - device: DEVICE, lm_head_name: str, pack_dtype: torch.dtype, +def create_quant_layer( + linear: nn.Module, + bits: int, + desc_act: bool, + dynamic, + group_size: int, + module, + names, + sym: bool, + device: DEVICE, + lm_head_name: str, + pack_dtype: torch.dtype, ) -> BaseQuantLinear: if isinstance(module, linear): return linear @@ -213,8 +232,8 @@ def create_quant_layer(linear: nn.Module, bits: int, desc_act: bool, dynamic, gr out_features = submodule.weight.shape[1] elif isinstance(submodule, BaseQuantLinear): # if submodule is already a quant layer, we need to get in_features and out_features from the submodule - in_features = submodule.infeatures - out_features = submodule.outfeatures + in_features = submodule.in_features + out_features = submodule.out_features else: raise NotImplementedError(f"Unsupported module {submodule}") @@ -246,9 +265,15 @@ def create_quant_layer(linear: nn.Module, bits: int, desc_act: bool, dynamic, gr # when loading a quantized model, device is target device passed in GPTQModel.load() # check in_features and out_features validate - _, err = linear.validate(bits=tmp_bits, group_size=tmp_group_size, desc_act=tmp_desc_act, sym=tmp_sym, - pack_dtype=tmp_pack_dtype, infeatures=in_features, outfeatures=out_features, - device=device) + _, err = linear.validate( + bits=tmp_bits, + group_size=tmp_group_size, + desc_act=tmp_desc_act, + sym=tmp_sym, + pack_dtype=tmp_pack_dtype, + in_features=in_features, + out_features=out_features, + device=device) if err is not None: raise err @@ -257,8 +282,8 @@ def create_quant_layer(linear: nn.Module, bits: int, desc_act: bool, dynamic, gr group_size=tmp_group_size, desc_act=tmp_desc_act, sym=tmp_sym, - infeatures=in_features, - outfeatures=out_features, + in_features=in_features, + out_features=out_features, pack_dtype=tmp_pack_dtype, bias=bias, #weight_dtype=submodule.qweight.dtype if isinstance(submodule, BaseQuantLinear) else submodule.weight.dtype, @@ -589,8 +614,8 @@ def gptqmodel_post_init(model, use_act_order: bool, quantize_config: QuantizeCon if use_act_order: device_to_buffers_size[device]["max_inner_outer_dim"] = max( device_to_buffers_size[device]["max_inner_outer_dim"], - submodule.infeatures, - submodule.outfeatures, + submodule.in_features, + submodule.out_features, ) if model_uses_exllama: diff --git a/tests/test_inference_speed.py b/tests/test_inference_speed.py index c372e925..1391b6ff 100644 --- a/tests/test_inference_speed.py +++ b/tests/test_inference_speed.py @@ -41,13 +41,13 @@ class TestInferenceSpeed(InferenceSpeed): @parameterized.expand( [ - 
(InferenceSpeed.NATIVE_MODEL_ID, BACKEND.MARLIN, 219.87), - (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.CUDA, 50.46), - (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.EXLLAMA_V1, 218.35), - (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.EXLLAMA_V2, 222.24), + (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.MARLIN, 224.39), + (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.CUDA, 50.67), + (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.EXLLAMA_V1, 221.48), + (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.EXLLAMA_V2, 225.14), (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.TRITON, 163.94), (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.TORCH, 50.65), - # (InferenceSpeed.BITBLAS_NATIVE_MODEL_ID, BACKEND.BITBLAS, 1474), # Second time running bitblas, there is cache + (InferenceSpeed.BITBLAS_NATIVE_MODEL_ID, BACKEND.BITBLAS, 1129.44), # Second time running bitblas, there is cache ] ) def test_inference_speed(self, model_path, backend, tokens_per_second): diff --git a/tests/test_packing.py b/tests/test_packing.py index 484bff88..60d03885 100644 --- a/tests/test_packing.py +++ b/tests/test_packing.py @@ -89,8 +89,8 @@ def pack(self, qlinearCls): group_size=self.group_size, sym=True, desc_act=True, - infeatures=self.k, - outfeatures=self.n, + in_features=self.k, + out_features=self.n, pack_dtype=self.pack_dtype, bias=False) diff --git a/tests/test_packing_speed.py b/tests/test_packing_speed.py index 4b843117..a9e100be 100644 --- a/tests/test_packing_speed.py +++ b/tests/test_packing_speed.py @@ -90,8 +90,8 @@ def pack(self, qlinearCls): group_size=self.group_size, sym=True, desc_act=True, - infeatures=self.k, - outfeatures=self.n, + in_features=self.k, + out_features=self.n, bias=False) qlinear.pack(self.linear, self.s.T, self.zeros.T, g_idx=None) diff --git a/tests/test_q4_exllama_v1.py b/tests/test_q4_exllama_v1.py index 14fbd4b4..2da5943e 100644 --- a/tests/test_q4_exllama_v1.py +++ b/tests/test_q4_exllama_v1.py @@ -1094,8 +1094,8 @@ def test_exllama(self): group_size=group_size, desc_act=False, sym=True, - infeatures=k, - outfeatures=n, + in_features=k, + out_features=n, bias=False, pack_dtype=pack_dtype, ) @@ -1113,7 +1113,7 @@ def test_exllama(self): linear = gptqmodel_post_init(linear, use_act_order=False) max_inner_outer_dim = max(k, n) - max_dq_buffer_size = linear.infeatures * linear.outfeatures + max_dq_buffer_size = linear.in_features * linear.out_features max_input_len = 2048 buffers = { "temp_state": torch.zeros((max_input_len, max_inner_outer_dim), dtype=torch.float16, device=device), diff --git a/tests/test_q4_exllama_v2.py b/tests/test_q4_exllama_v2.py index e1a239cf..2d4a7b4e 100644 --- a/tests/test_q4_exllama_v2.py +++ b/tests/test_q4_exllama_v2.py @@ -57,8 +57,8 @@ def test_exllamav2(self): group_size=group_size, desc_act=False, sym=True, - infeatures=k, - outfeatures=n, + in_features=k, + out_features=n, bias=False, pack_dtype=pack_dtype, )
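Usage note on the rename above: callers that constructed quant-linear modules with the old `infeatures`/`outfeatures` keywords must now pass `in_features`/`out_features`, and `bias` is forwarded to `BaseQuantLinear`, which registers the shared buffers when `register_buffers=True`. Below is a minimal, illustrative sketch of the new calling convention using the Torch kernel touched in this diff; the layer shapes and group size are assumptions chosen for the example, not values taken from this PR.

# Hypothetical example (not part of this diff): constructing the Torch kernel
# with the renamed keyword arguments introduced by this change.
import torch

from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear

layer = TorchQuantLinear(
    bits=4,
    group_size=128,
    sym=True,
    desc_act=False,
    in_features=4096,        # previously `infeatures`
    out_features=11008,      # previously `outfeatures`
    bias=False,              # bias is now passed through to BaseQuantLinear
    pack_dtype=torch.int32,
)
# TorchQuantLinear passes register_buffers=True to BaseQuantLinear, so qweight,
# qzeros, scales and g_idx are registered by the base class instead of being
# re-declared in every kernel implementation.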