From 767ef9c01ca5e3832ab74569e63a9abd22ef72a3 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Thu, 8 Aug 2024 18:35:49 -0700
Subject: [PATCH] [TPU] Add Load-time W8A16 quantization for TPU Backend
 (#7005)

Signed-off-by: Alvant
---
 vllm/config.py                                  |   6 +
 .../layers/quantization/__init__.py             |   2 +
 .../layers/quantization/tpu_int8.py             | 118 ++++++++++++++++++
 vllm/model_executor/model_loader/loader.py      |  17 +--
 4 files changed, 135 insertions(+), 8 deletions(-)
 create mode 100644 vllm/model_executor/layers/quantization/tpu_int8.py

diff --git a/vllm/config.py b/vllm/config.py
index 63a5acc50b943..6fc0045fb93aa 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -244,6 +244,7 @@ def _verify_quantization(self) -> None:
             "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
             "fbgemm_fp8", "compressed_tensors", "compressed-tensors"
         ]
+        tpu_supported_quantization = ["tpu_int8"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
 
@@ -282,6 +283,11 @@ def _verify_quantization(self) -> None:
                 raise ValueError(
                     f"{self.quantization} quantization is currently not "
                     f"supported in ROCm.")
+            if is_tpu(
+            ) and self.quantization not in tpu_supported_quantization:
+                raise ValueError(
+                    f"{self.quantization} quantization is currently not "
+                    f"supported in TPU Backend.")
             if self.quantization not in optimized_quantization_methods:
                 logger.warning(
                     "%s quantization is not fully "
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index db2a245561699..e1b3bc9b4ad54 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -22,11 +22,13 @@
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
+from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
     "awq": AWQConfig,
     "deepspeedfp": DeepSpeedFPConfig,
+    "tpu_int8": Int8TpuConfig,
     "fp8": Fp8Config,
     "fbgemm_fp8": FBGEMMFp8Config,
     # The order of gptq methods is important for config.py iteration over
diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py
new file mode 100644
index 0000000000000..ae34e01497db4
--- /dev/null
+++ b/vllm/model_executor/layers/quantization/tpu_int8.py
@@ -0,0 +1,118 @@
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.utils import set_weight_attrs
+
+ACTIVATION_SCHEMES = ["none"]
+
+
+class Int8TpuConfig(QuantizationConfig):
+    """Int8 Quantization Config class for TPU Backend."""
+
+    def __init__(
+        self,
+        activation_scheme: str = "none",
+    ) -> None:
+        if activation_scheme not in ACTIVATION_SCHEMES:
+            raise ValueError(
+                f"Unsupported activation scheme {activation_scheme}")
+        self.activation_scheme = activation_scheme
+
+    def get_name(self) -> str:
+        return "tpu_int8"
+
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        raise NotImplementedError(
+            "This function should not be called with TPU Backend")
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "Int8TpuConfig":
+        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
+        return cls(activation_scheme=activation_scheme)
+
+    def get_quant_method(self, layer: Module,
+                         prefix: str) -> Optional["TPUInt8LinearMethod"]:
+        if isinstance(layer, LinearBase):
+            return TPUInt8LinearMethod(self)
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class TPUInt8LinearMethod(LinearMethodBase):
+    """Int8 Linear method for TPU Quant. """
+
+    def __init__(self, quant_config: Int8TpuConfig):
+        self.quant_config = quant_config
+
+    def create_weights(self, layer: Module, input_size_per_partition: int,
+                       output_partition_sizes: List[int], input_size: int,
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        weight = Parameter(torch.empty(sum(output_partition_sizes),
+                                       input_size_per_partition,
+                                       dtype=params_dtype),
+                           requires_grad=False)
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, {
+            **extra_weight_attrs,
+            "input_dim": 1,
+            "output_dim": 0,
+        })
+
+    def _quantize_weight(
+            self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        weight_dtype = weight.dtype
+        weight = weight.cpu().to(torch.float32)
+        n_bit = 8
+        eps = 1e-5
+        max_int = 2**(n_bit - 1) - 1
+        min_int = -(2**(n_bit - 1))
+        max_val = weight.abs().amax(dim=-1, keepdim=True)
+        max_val = max_val.clamp(min=eps)
+        qscale = max_val / max_int
+        qweight = torch.clamp(torch.round(weight * (1.0 / qscale)), min_int,
+                              max_int).to(torch.int8)
+        qscale = qscale.squeeze().to(weight_dtype)
+        return qweight, qscale
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        device = layer.weight.device
+        qweight, qscale = self._quantize_weight(layer.weight)
+        qweight = qweight.to(device)
+        qscale = qscale.to(device)
+        layer.weight = Parameter(qweight, requires_grad=False)
+        layer.scale = Parameter(qscale, requires_grad=False)
+
+    def apply(self,
+              layer: torch.nn.Module,
+              x: torch.Tensor,
+              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        try:
+            import torch_xla.experimental.xla_quantized_matmul  # noqa: F401
+        except ImportError as err:
+            raise ImportError(
+                "Please install torch_xla by following the instructions at "
+                "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html "  # noqa: E501
+                "to run vLLM on TPU.") from err
+        weight = layer.weight
+        scale = layer.scale
+        out = torch.ops.xla.quantized_matmul(x, weight, scale)
+        if bias is not None:
+            out = out + bias
+        return out
diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 44c04c9ba8ddc..ba9c8af88f864 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -94,14 +94,15 @@ def _get_quantization_config(
     """Get the quantization config."""
     if model_config.quantization is not None:
         quant_config = get_quant_config(model_config, load_config)
-        capability = current_platform.get_device_capability()
-        capability = capability[0] * 10 + capability[1]
-        if capability < quant_config.get_min_capability():
-            raise ValueError(
-                f"The quantization method {model_config.quantization} is not "
-                "supported for the current GPU. "
-                f"Minimum capability: {quant_config.get_min_capability()}. "
-                f"Current capability: {capability}.")
+        if not is_tpu():
+            capability = current_platform.get_device_capability()
+            capability = capability[0] * 10 + capability[1]
+            if capability < quant_config.get_min_capability():
+                raise ValueError(
+                    f"The quantization method {model_config.quantization} "
+                    "is not supported for the current GPU. "
+                    f"Minimum capability: {quant_config.get_min_capability()}. "
+                    f"Current capability: {capability}.")
         supported_dtypes = quant_config.get_supported_act_dtypes()
         if model_config.dtype not in supported_dtypes:
             raise ValueError(