diff --git a/tests/quantization/test_register_quantization_config.py b/tests/quantization/test_register_quantization_config.py
new file mode 100644
index 0000000000000..8e7f44a399ddf
--- /dev/null
+++ b/tests/quantization/test_register_quantization_config.py
@@ -0,0 +1,117 @@
+"""Tests for registering a custom quantization config.
+
+See https://github.com/vllm-project/vllm/issues/11926 for more details.
+
+Run `pytest tests/quantization/test_register_quantization_config.py`.
+"""
+from typing import Any, Dict, List, Optional
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from vllm.model_executor.layers.linear import LinearBase  # noqa: E501
+from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+from vllm.model_executor.layers.quantization import (
+    get_quantization_config, register_quantization_config)
+from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
+    QuantizationConfig)
+
+
+class FakeQuantLinearMethod(UnquantizedLinearMethod):
+    """Fake quantization linear method for per-token dynamic quantization."""
+
+    def __init__(self, num_bits: int = 8) -> None:
+        """Initialize the quantization method."""
+        super().__init__()
+        self.num_bits = num_bits
+
+    def apply(self,
+              layer: "torch.nn.Module",
+              x: "torch.Tensor",
+              bias: Optional["torch.Tensor"] = None) -> "torch.Tensor":
+        """Perform fake quantization before the linear layer."""
+
+        # Calculate the scales dynamically
+        max_val = torch.amax(x, dim=(0, -1), keepdims=True)
+        min_val = torch.amin(x, dim=(0, -1), keepdims=True)
+        scales = (max_val - min_val) / (2**self.num_bits - 1)
+
+        # Fake quantize the input
+        quant_x = torch.clamp(torch.round(x / scales), -2**(self.num_bits - 1),
+                              2**(self.num_bits - 1) - 1)
+        dequant_x = quant_x * scales
+
+        return F.linear(dequant_x, layer.weight, bias)
+
+
+@register_quantization_config("custom_quant")
+class CustomQuantConfig(QuantizationConfig):
+    """Custom quantization config for per-token dynamic fake quantization."""
+
+    def __init__(self, num_bits: int = 8) -> None:
+        """Initialize the quantization config."""
+        self.num_bits = num_bits
+
+    def get_name(self) -> str:
+        """Name of the quantization method."""
+        return "custom_quant"
+
+    def get_supported_act_dtypes(self) -> List["torch.dtype"]:
+        """List of supported activation dtypes."""
+        return [torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        """Minimum GPU capability to support the quantization method."""
+        return -1
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        """List of filenames to search for in the model directory."""
+        return []
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "CustomQuantConfig":
+        """Create a config class from the model's quantization config."""
+        return CustomQuantConfig(num_bits=config.get("num_bits", 8))
+
+    def get_quant_method(self, layer: "torch.nn.Module",
+                         prefix: str) -> Optional["FakeQuantLinearMethod"]:
+        """Get the quantization method to use for the quantized layer."""
+        if isinstance(layer, LinearBase):
+            return FakeQuantLinearMethod(num_bits=self.num_bits)
+        return None
+
+
+def test_register_quantization_config():
+    """Test registering a custom quantization config."""
+
+    # The quantization method `custom_quant` should be registered.
+    assert get_quantization_config("custom_quant") == CustomQuantConfig
+
+    # The quantization method `custom_quant` already exists, so
+    # registering it again should raise an error.
+    with pytest.raises(ValueError):
+        register_quantization_config("custom_quant")(CustomQuantConfig)
+
+
+@pytest.mark.parametrize(argnames="model",
+                         argvalues=[
+                             "meta-llama/Meta-Llama-3-8B-Instruct",
+                         ])
+def test_custom_quant(vllm_runner, model):
+    """Test inference with the custom quantization method."""
+    with vllm_runner(model_name=model,
+                     quantization="custom_quant",
+                     enforce_eager=True) as llm:
+
+        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
+        layer = model.model.layers[0]
+        qkv_proj = layer.self_attn.qkv_proj
+
+        # Check that the quantization method is FakeQuantLinearMethod
+        assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)
+
+        output = llm.generate_greedy("Hello my name is", max_tokens=20)
+        assert output
diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py
index caeb8b95e02f2..d2bde13fcf546 100644
--- a/vllm/model_executor/layers/quantization/__init__.py
+++ b/vllm/model_executor/layers/quantization/__init__.py
@@ -29,6 +29,45 @@
     "quark"
 ]
 
+# Customized quantization methods are added to this dict.
+_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}
+
+
+def register_quantization_config(quantization: str):
+    """Register a customized vLLM quantization config.
+
+    When a quantization method is not supported by vLLM, you can register a
+    customized quantization config to support it.
+
+    Args:
+        quantization (str): The quantization method name.
+
+    Examples:
+        >>> from vllm.model_executor.layers.quantization import register_quantization_config
+        >>> from vllm.model_executor.layers.quantization import get_quantization_config
+        >>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+        >>>
+        >>> @register_quantization_config("my_quant")
+        ... class MyQuantConfig(QuantizationConfig):
+        ...     pass
+        >>>
+        >>> get_quantization_config("my_quant")
+
+    """  # noqa: E501
+
+    def _wrapper(quant_config_cls):
+        if quantization in QUANTIZATION_METHODS:
+            raise ValueError(
+                f"The quantization method `{quantization}` already exists.")
+        if not issubclass(quant_config_cls, QuantizationConfig):
+            raise ValueError("The quantization config must be a subclass of "
+                             "`QuantizationConfig`.")
+        _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls
+        QUANTIZATION_METHODS.append(quantization)
+        return quant_config_cls
+
+    return _wrapper
+
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in QUANTIZATION_METHODS:
@@ -84,6 +123,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
         "ipex": IPEXConfig,
         "quark": QuarkConfig
     }
+    # Update `method_to_config` with customized quantization methods.
+    method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
    return method_to_config[quantization]
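
For reference, a minimal end-to-end usage sketch of the new registration hook, assuming this patch is applied. `MyQuantConfig`, the no-op linear method, and the `facebook/opt-125m` model name are illustrative placeholders, not part of this change:

    # Sketch: register a custom quantization config, then load a model with it.
    # `MyQuantConfig` simply falls back to the unquantized linear method.
    from typing import Any, Dict, List, Optional

    import torch

    from vllm import LLM
    from vllm.model_executor.layers.linear import (LinearBase,
                                                   UnquantizedLinearMethod)
    from vllm.model_executor.layers.quantization import (
        register_quantization_config)
    from vllm.model_executor.layers.quantization.base_config import (
        QuantizationConfig)


    @register_quantization_config("my_quant")
    class MyQuantConfig(QuantizationConfig):
        """A no-op config: every linear layer stays unquantized."""

        def get_name(self) -> str:
            return "my_quant"

        def get_supported_act_dtypes(self) -> List[torch.dtype]:
            return [torch.float16, torch.bfloat16]

        @classmethod
        def get_min_capability(cls) -> int:
            return -1

        @staticmethod
        def get_config_filenames() -> List[str]:
            return []

        @classmethod
        def from_config(cls, config: Dict[str, Any]) -> "MyQuantConfig":
            return cls()

        def get_quant_method(self, layer: torch.nn.Module,
                             prefix: str) -> Optional[UnquantizedLinearMethod]:
            # Use the plain unquantized kernel for all linear layers.
            if isinstance(layer, LinearBase):
                return UnquantizedLinearMethod()
            return None


    # The registered name can now be passed as a regular `quantization` arg.
    llm = LLM(model="facebook/opt-125m",
              quantization="my_quant",
              enforce_eager=True)
    print(llm.generate("Hello, my name is"))

Note that `get_quantization_config` consults `_CUSTOMIZED_METHOD_TO_QUANT_CONFIG` at model load time, so the decorator must run before the `LLM(...)` engine is constructed; in a plugin this registration would typically happen at module import.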