Push register buffer down to base class and rename all in/out features (#1193)

* push register buffer down to base cls and rename all in/out features

* format

* remove in_features padding math

* use in-place bias add

* fix bitblas + change compile to use `max-autotune`

* update data with pytorch 2.6.0
Qubitium authored Feb 1, 2025
1 parent ddf3044 commit 9e4129c
Showing 19 changed files with 367 additions and 350 deletions.
11 changes: 7 additions & 4 deletions gptqmodel/models/base.py
@@ -899,7 +899,7 @@ def save(
else:
self.save_pretrained(save_dir, **kwargs)

def compile(self, backend="inductor", mode="reduce-overhead"):
def compile(self, backend="inductor", mode="max-autotune"):
if not self.quantized:
logger.warning("model is not quantized, skip compiling...")
return self
@@ -914,9 +914,12 @@ def compile(self, backend="inductor", mode="reduce-overhead"):

try:
self.model = torch.compile(self.model, fullgraph=True, backend=backend, mode=mode)
except Exception:
logger.info("Compiling model again with `fullgraph=False`")
self.model = torch.compile(self.model, fullgraph=False, backend=backend, mode=mode)
except Exception as e:
logger.info(f"Compiling model again with `fullgraph=False`; `full-graph=True` compile failed: {e}")
try:
self.model = torch.compile(self.model, fullgraph=False, backend=backend, mode=mode)
except Exception as e:
logger.info(f"Compiling model failed: running model in non-compiled mode. {e}")
return self

def serve(self,
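Context for the hunk above: the new fallback chain is the standard `torch.compile` retry pattern, trying `fullgraph=True` first and ending at the uncompiled model. A minimal standalone sketch of that pattern, assuming only `torch` and Python's `logging` (the helper name `compile_with_fallback` is illustrative and not part of the GPTQModel API):

```python
import logging

import torch

logger = logging.getLogger(__name__)


def compile_with_fallback(model, backend="inductor", mode="max-autotune"):
    # Try a full-graph compile first; it gives the best speedup when it succeeds.
    try:
        return torch.compile(model, fullgraph=True, backend=backend, mode=mode)
    except Exception as e:
        logger.info(f"`fullgraph=True` compile failed, retrying with `fullgraph=False`: {e}")

    # Fall back to a partial-graph compile, and finally to the eager (uncompiled) model.
    try:
        return torch.compile(model, fullgraph=False, backend=backend, mode=mode)
    except Exception as e:
        logger.info(f"Compile failed entirely, running model in non-compiled mode: {e}")
        return model
```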
100 changes: 78 additions & 22 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import sys
from typing import List, Optional, Tuple

@@ -39,13 +39,24 @@ class BaseQuantLinear(nn.Module):
SUPPORTS_DEVICES: List[DEVICE] = None
SUPPORTS_PLATFORM: List[PLATFORM] = None

def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures: int, outfeatures: int, pack_dtype: t.dtype, *args,
def __init__(self,
bits: int,
group_size: int,
desc_act: bool,
sym: bool,
in_features: int,
out_features: int,
bias: bool,
pack_dtype: t.dtype,
register_buffers: bool = False,
register_buffers_in_features: int = None,
register_buffers_out_features: int = None,
**kwargs):
super().__init__()

self.infeatures = infeatures
self.outfeatures = outfeatures
self.group_size = group_size if group_size != -1 else infeatures
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features
self.bits = bits
self.desc_act = desc_act
self.pack_dtype = pack_dtype
@@ -73,19 +84,64 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat
raise ValueError("Unsupported weight_dtype. Only int16 and int32 are supported.")

self.pack_factor = self.pack_dtype_bits // self.bits
_, err = self._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, infeatures=infeatures,outfeatures=outfeatures, pack_dtype=pack_dtype)
_, err = self._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, in_features=in_features, out_features=out_features, pack_dtype=pack_dtype)
if err:
raise err

# most kernels share same buffers so they can share same register buffer code
if register_buffers:
# some kernels auto-pads in/out features
in_features = self.in_features if not register_buffers_in_features else register_buffers_in_features
out_features = self.out_features if not register_buffers_out_features else register_buffers_out_features

self.register_buffer(
"qweight",
t.zeros((in_features // self.pack_dtype_bits * self.bits, out_features), dtype=self.pack_dtype),
)
self.register_buffer(
"qzeros",
t.zeros(
(
math.ceil(in_features / self.group_size),
out_features // self.pack_dtype_bits * self.bits,
),
dtype=self.pack_dtype,
),
)
self.register_buffer(
"scales",
t.zeros(
(math.ceil(in_features / self.group_size), out_features),
dtype=t.float16, # Scales are always float16
),
)
self.register_buffer(
"g_idx",
t.tensor([i // self.group_size for i in range(in_features)], dtype=t.int32),
)
if bias:
self.register_buffer("bias", t.zeros(out_features, dtype=t.float16))
else:
self.bias = None

@classmethod
# custom quant linear class can override this and add custom checks
def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None,
outfeatures:int=None, pack_dtype:t.dtype=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[
def validate(
cls,
bits: int,
group_size: int,
desc_act: bool,
sym: bool,
in_features:int=None,
out_features:int=None,
pack_dtype:t.dtype=None,
dynamic:Optional[dict]=None,
device:Optional[DEVICE]=None,
trainable:Optional[bool]=None) -> Tuple[
bool, Optional[Exception]]:
validate, err = cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym,
infeatures=infeatures, outfeatures=outfeatures, pack_dtype=pack_dtype, dynamic=dynamic,
device=device, trainable=trainable)
return validate, err
return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym,
in_features=in_features, out_features=out_features, pack_dtype=pack_dtype,
dynamic=dynamic, device=device, trainable=trainable)

@classmethod
# internal method and should not be overriden
Expand Down Expand Up @@ -121,8 +177,8 @@ def verify_supports_params(cls):
raise ValueError(f"{cls.__name__}.{name} cannot be None or an empty list.")

@classmethod
def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: bool=False, pack_dtype:t.dtype=None, dynamic:Optional[dict]=None, infeatures:int=None,
outfeatures:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[bool, Optional[Exception]]:
def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: bool=False, pack_dtype:t.dtype=None, dynamic:Optional[dict]=None, in_features:int=None,
out_features:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[bool, Optional[Exception]]:
cls.verify_supports_params()

if pack_dtype not in cls.SUPPORTS_PACK_DTYPES:
@@ -193,20 +249,20 @@ def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym:
err = f"{cls} only supports `{cls.SUPPORTS_DESC_ACT}` bits: actual desc_act = `{desc_act}` for layer `{layer}`"
return False, NotImplementedError(err)

if infeatures is not None:
validate = all(infeatures % in_fea == 0 for in_fea in cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY)
if in_features is not None:
validate = all(in_features % in_fea == 0 for in_fea in cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY)
if not validate:
err = f"{cls}: `infeatures` must be divisible by {cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY}."
err = f"{cls}: `in_features` must be divisible by {cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY}."
return False, NotImplementedError(err)

validate = infeatures % group_size == 0 or cls.SUPPORTS_AUTO_PADDING
validate = in_features % group_size == 0 or cls.SUPPORTS_AUTO_PADDING
if not validate:
err = f"{cls}: `infeatures` must be divisible by `group_size: {group_size}`."
err = f"{cls}: `in_features` must be divisible by `group_size: {group_size}`."
return False, NotImplementedError(err)
if outfeatures is not None:
validate = all(outfeatures % out_fea == 0 for out_fea in cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY)
if out_features is not None:
validate = all(out_features % out_fea == 0 for out_fea in cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY)
if not validate:
err = f"{cls}: `outfeatures` must be divisible by {cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY}."
err = f"{cls}: `out_features` must be divisible by {cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY}."
return False, NotImplementedError(err)
return True, None

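To illustrate what the new `register_buffers=True` path in `BaseQuantLinear.__init__` allocates, here is a hedged, standalone sketch that reproduces the buffer shapes from the hunk above; the concrete geometry (4-bit, group size 128, int32 packing, a 4096×11008 layer) is an illustrative assumption, not taken from the commit:

```python
import math

import torch as t

# Illustrative layer geometry (assumed values, not from the commit).
bits, group_size = 4, 128
in_features, out_features = 4096, 11008
pack_dtype = t.int32
pack_dtype_bits = 32  # bit width of the packing dtype

# Shapes mirror the base-class register_buffer calls above.
qweight = t.zeros((in_features // pack_dtype_bits * bits, out_features), dtype=pack_dtype)
qzeros = t.zeros(
    (math.ceil(in_features / group_size), out_features // pack_dtype_bits * bits),
    dtype=pack_dtype,
)
scales = t.zeros((math.ceil(in_features / group_size), out_features), dtype=t.float16)
g_idx = t.tensor([i // group_size for i in range(in_features)], dtype=t.int32)

print(qweight.shape, qzeros.shape, scales.shape, g_idx.shape)
# torch.Size([512, 11008]) torch.Size([32, 1376]) torch.Size([32, 11008]) torch.Size([4096])
```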
42 changes: 26 additions & 16 deletions gptqmodel/nn_modules/qlinear/bitblas.py
@@ -116,8 +116,8 @@ def __init__(
group_size: int,
desc_act: bool,
sym: bool,
infeatures: int,
outfeatures: int,
in_features: int,
out_features: int,
pack_dtype: torch.dtype,
bias: bool,
enable_tuning: bool = True,
Expand All @@ -127,18 +127,28 @@ def __init__(
layout: str = "nt",
**kwargs,
):
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures, outfeatures=outfeatures, pack_dtype=pack_dtype, **kwargs)
super().__init__(
bits=bits,
group_size=group_size,
sym=sym,
desc_act=desc_act,
in_features=in_features,
out_features=out_features,
bias=bias,
pack_dtype=pack_dtype,
register_buffers=False,
**kwargs)

import_bitblas()

self._validate_parameters(group_size, infeatures, outfeatures)
self._validate_parameters(group_size, in_features, out_features)

self.opt_features = opt_features
self.target = BITBLAS_TARGET
self._configure_bitblas_matmul(
enable_tuning, fast_decoding, bias, propagate_b, layout, bits
)
self._initialize_buffers(infeatures, outfeatures, bias)
self._initialize_buffers(in_features, out_features, bias)
self.reset_parameters()

@classmethod
@@ -148,12 +158,12 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
return cls._validate(**args)

def _validate_parameters(
self, group_size: int, infeatures: int, outfeatures: int
self, group_size: int, in_features: int, out_features: int
):
if infeatures % group_size != 0:
raise ValueError("`infeatures` must be divisible by `group_size`.")
if in_features % group_size != 0:
raise ValueError("`in_features` must be divisible by `group_size`.")

def _initialize_buffers(self, infeatures: int, outfeatures: int, bias: bool):
def _initialize_buffers(self, in_features: int, out_features: int, bias: bool):
self.register_buffer(
"qweight",
torch.zeros(
@@ -164,28 +174,28 @@ def _initialize_buffers(self, infeatures: int, outfeatures: int, bias: bool):
self.register_buffer(
"scales",
torch.zeros(
(outfeatures, infeatures // self.group_size), dtype=self.TORCH_DTYPE
(out_features, in_features // self.group_size), dtype=self.TORCH_DTYPE
),
)
if self.zeros_mode == "quantized":
storage_nbit = int("".join(c for c in self.STORAGE_DTYPE if c.isdigit()))
self.register_buffer(
"zeros",
torch.zeros(
(infeatures // self.group_size, outfeatures // storage_nbit * self.bits), dtype=self.TORCH_STORAGE_DTYPE
(in_features // self.group_size, out_features // storage_nbit * self.bits), dtype=self.TORCH_STORAGE_DTYPE
),
)
else:
self.register_buffer(
"zeros",
torch.zeros(
(outfeatures, infeatures // self.group_size), dtype=self.TORCH_DTYPE
(out_features, in_features // self.group_size), dtype=self.TORCH_DTYPE
),
)

if bias:
self.register_buffer(
"bias", torch.zeros((outfeatures), dtype=self.TORCH_DTYPE)
"bias", torch.zeros((out_features), dtype=self.TORCH_DTYPE)
)
else:
self.bias = None
Expand All @@ -200,8 +210,8 @@ def _configure_bitblas_matmul(
W_dtype = f"uint{bits}"
matmul_config = MatmulConfig(
M=self.opt_features,
N=self.outfeatures,
K=self.infeatures,
N=self.out_features,
K=self.in_features,
A_dtype=bitblas_dtype,
W_dtype=W_dtype,
out_dtype=bitblas_dtype,
@@ -341,7 +351,7 @@ def pack(self, linear, scales, zeros, g_idx=None):
def repack_from_gptq(self, gptq_module):
from bitblas.quantization.utils import general_compress

# qweight in gptq old quant linear stored with (outfeatures, infeatures), should be transposed.
# qweight in gptq old quant linear stored with (out_features, in_features), should be transposed.
qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE)
if self.bitblas_matmul.weight_transform is not None:
qweight = self.bitblas_matmul.weight_transform(qweight.cpu()).cuda()
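The `repack_from_gptq` hunk above relies on the usual transpose-then-reinterpret idiom before handing the weight to the BitBLAS weight transform. A small hedged sketch of that step in isolation; the shapes and the `int8` storage dtype are illustrative assumptions (the real code uses `self.TORCH_STORAGE_DTYPE`):

```python
import torch

# Illustrative packed weight in the old (out_features, in_features)-style layout.
qweight_old = torch.randint(0, 2**31 - 1, (64, 128), dtype=torch.int32)

# Transpose to the layout the BitBLAS kernel expects, make it contiguous,
# then reinterpret the raw bits under the kernel's storage dtype.
qweight_new = qweight_old.T.contiguous().view(torch.int8)
print(qweight_new.shape)  # torch.Size([128, 256]) -- 4 int8 values per int32
```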
31 changes: 20 additions & 11 deletions gptqmodel/nn_modules/qlinear/dynamic_cuda.py
@@ -56,30 +56,39 @@ def __init__(
group_size: int,
sym: bool,
desc_act: bool,
infeatures: int,
outfeatures: int,
in_features: int,
out_features: int,
bias: bool,
pack_dtype: torch.dtype,
kernel_switch_threshold=128,
**kwargs,
):
if gptqmodel_cuda_import_exception is not None:
raise ValueError(
f"Trying to use the cuda backend, but could not import the C++/CUDA dependencies with the following error: {gptqmodel_cuda_import_exception}"
)
super().__init__(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, infeatures=infeatures,
outfeatures=outfeatures, bias=bias, **kwargs)
super().__init__(
bits=bits,
group_size=group_size,
sym=sym,
desc_act=desc_act,
in_features=in_features,
out_features=out_features,
bias=bias,
pack_dtype=pack_dtype,
**kwargs)

# assert in_features % 64 == 0 and out_features % 64 == 0

self.kernel_switch_threshold = kernel_switch_threshold

# use faster cuda_256 by default
self.gptqmodel_cuda = gptqmodel_cuda_256

# fall back to cuda_64
if infeatures % 256 != 0 or outfeatures % 256 != 0:
if in_features % 256 != 0 or out_features % 256 != 0:
self.gptqmodel_cuda = gptqmodel_cuda_64

assert infeatures % 64 == 0 and outfeatures % 64 == 0

if self.bits == 4:
self.qmatmul = self.gptqmodel_cuda.vecquant4matmul
elif self.bits == 8:
@@ -96,10 +105,10 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]:
return cls._validate(**args)

def forward(self, x: torch.Tensor):
out_shape = x.shape[:-1] + (self.outfeatures,)
out_shape = x.shape[:-1] + (self.out_features,)
x = x.reshape(-1, x.shape[-1])

assert x.device.type == "cuda"
# assert x.device.type == "cuda"

# switch to torch kernel when input shape is >= kernel_switch_threshold
# cuda is only optimized for < kernel_switch_threshold and will run slower than torch otherwise
@@ -108,7 +117,7 @@ def forward(self, x: torch.Tensor):
f"Input shape `{x.shape[0]}` >= `{self.kernel_switch_threshold}` is not optimized for cuda kernel: dynamic switching to torch kernel.")
return self._forward(x, x.dtype, out_shape)

out = torch.zeros((x.shape[0], self.outfeatures), device=x.device, dtype=torch.float32)
out = torch.zeros((x.shape[0], self.out_features), device=x.device, dtype=torch.float32)
self.qmatmul(
x.to(dtype=torch.float32),
self.qweight,
Expand All @@ -120,7 +129,7 @@ def forward(self, x: torch.Tensor):

out = out.to(x.dtype).reshape(out_shape)
if self.bias is not None:
out = out + self.bias
out.add_(self.bias)
return out


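The "use in-place bias add" bullet from the commit message shows up in the last hunk: `out = out + self.bias` becomes `out.add_(self.bias)`, which broadcasts the bias into the existing output tensor instead of allocating a new one. A tiny sketch of the equivalence, with made-up shapes:

```python
import torch

out = torch.randn(8, 11008)
bias = torch.randn(11008)

# Out-of-place: allocates a fresh tensor for the sum.
ref = out + bias

# In-place: broadcasts bias into `out` without a new allocation.
out.add_(bias)

assert torch.equal(out, ref)
```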
