[Misc] Support register quantization method out-of-tree #11969

Merged
117 changes: 117 additions & 0 deletions tests/quantization/test_register_quantization_config.py
@@ -0,0 +1,117 @@
"""Tests register custom quantization config.
See https://github.com/vllm-project/vllm/issues/11926 for more details.
Run `pytest tests/quantization/test_register_quantization_config.py`.
"""
from typing import Any, Dict, List, Optional

import pytest
import torch
import torch.nn.functional as F

from vllm.model_executor.layers.linear import LinearBase # noqa: E501
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import (
    get_quantization_config, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
    QuantizationConfig)


class FakeQuantLinearMethod(UnquantizedLinearMethod):
    """Fake quantization linear method for per-token dynamic quantization."""

    def __init__(self, num_bits: int = 8) -> None:
        """Initialize the quantization method."""
        super().__init__()
        self.num_bits = num_bits

    def apply(self,
              layer: "torch.nn.Module",
              x: "torch.Tensor",
              bias: Optional["torch.Tensor"] = None) -> "torch.Tensor":
        """Perform fake quantization before the linear layer."""

        # Calculate the scales dynamically.
        max_val = torch.amax(x, dim=(0, -1), keepdim=True)
        min_val = torch.amin(x, dim=(0, -1), keepdim=True)
        scales = (max_val - min_val) / (2**self.num_bits - 1)

        # Fake-quantize the input: round to the integer grid, clamp to the
        # representable range, then dequantize back to floating point.
        quant_x = torch.clamp(torch.round(x / scales), -2**(self.num_bits - 1),
                              2**(self.num_bits - 1) - 1)
        dequant_x = quant_x * scales

        return F.linear(dequant_x, layer.weight, bias)


@register_quantization_config("custom_quant")
class CustomQuantConfig(QuantizationConfig):
    """Custom quantization config for per-token dynamic fake quantization."""

    def __init__(self, num_bits: int = 8) -> None:
        """Initialize the quantization config."""
        self.num_bits = num_bits

    def get_name(self) -> str:
        """Name of the quantization method."""
        return "custom_quant"

    def get_supported_act_dtypes(self) -> List["torch.dtype"]:
        """List of supported activation dtypes."""
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum GPU capability to support the quantization method."""
        return -1

    @staticmethod
    def get_config_filenames() -> List[str]:
        """List of filenames to search for in the model directory."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CustomQuantConfig":
        """Create a config class from the model's quantization config."""
        return CustomQuantConfig(num_bits=config.get("num_bits", 8))

    def get_quant_method(self, layer: "torch.nn.Module",
                         prefix: str) -> Optional["FakeQuantLinearMethod"]:
        """Get the quantization method to use for the given layer."""
        if isinstance(layer, LinearBase):
            return FakeQuantLinearMethod(num_bits=self.num_bits)
        return None


def test_register_quantization_config():
    """Test registering a custom quantization config."""

    # The quantization method `custom_quant` should be registered.
    assert get_quantization_config("custom_quant") == CustomQuantConfig

    # The quantization method `custom_quant` already exists, so registering
    # it again should raise an error.
    with pytest.raises(ValueError):
        register_quantization_config("custom_quant")(CustomQuantConfig)


@pytest.mark.parametrize(argnames="model",
                         argvalues=[
                             "meta-llama/Meta-Llama-3-8B-Instruct",
                         ])
def test_custom_quant(vllm_runner, model):
    """Test inference with the custom quantization method."""
    with vllm_runner(model_name=model,
                     quantization="custom_quant",
                     enforce_eager=True) as llm:

        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
        layer = model.model.layers[0]
        qkv_proj = layer.self_attn.qkv_proj

        # Check that the quantization method is FakeQuantLinearMethod.
        assert isinstance(qkv_proj.quant_method, FakeQuantLinearMethod)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
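
The test above mirrors the intended out-of-tree workflow: importing a module that applies the decorator is enough to make the method available by name. A minimal sketch of end-user code, where `my_package.custom_quant` is a hypothetical module defining a decorated config such as the `CustomQuantConfig` shown above:

from vllm import LLM

# Importing the module runs the @register_quantization_config decorator,
# which makes "custom_quant" a valid value for the `quantization` argument.
import my_package.custom_quant  # noqa: F401  (hypothetical module)

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct",
          quantization="custom_quant",
          enforce_eager=True)
outputs = llm.generate("Hello my name is")
print(outputs[0].outputs[0].text)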
41 changes: 41 additions & 0 deletions vllm/model_executor/layers/quantization/__init__.py
@@ -29,6 +29,45 @@
"quark"
]

# Customized quantization methods will be added to this dict.
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}


def register_quantization_config(quantization: str):
    """Register a customized vllm quantization config.

    When a quantization method is not supported by vllm, you can register a
    customized quantization config to support it.

    Args:
        quantization (str): The quantization method name.

    Examples:
        >>> from vllm.model_executor.layers.quantization import register_quantization_config
        >>> from vllm.model_executor.layers.quantization import get_quantization_config
        >>> from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
        >>>
        >>> @register_quantization_config("my_quant")
        ... class MyQuantConfig(QuantizationConfig):
        ...     pass
        >>>
        >>> get_quantization_config("my_quant")
        <class 'MyQuantConfig'>
    """  # noqa: E501

    def _wrapper(quant_config_cls):
        if quantization in QUANTIZATION_METHODS:
            raise ValueError(
                f"The quantization method `{quantization}` already exists.")
        if not issubclass(quant_config_cls, QuantizationConfig):
            raise ValueError("The quantization config must be a subclass of "
                             "`QuantizationConfig`.")
        _CUSTOMIZED_METHOD_TO_QUANT_CONFIG[quantization] = quant_config_cls
        QUANTIZATION_METHODS.append(quantization)
        return quant_config_cls

    return _wrapper


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
@@ -84,6 +123,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"ipex": IPEXConfig,
"quark": QuarkConfig
}
# Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)

return method_to_config[quantization]
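
For reference, the registry round-trip these changes enable can be sanity-checked in a few lines. This is a sketch assuming the decorated `CustomQuantConfig` from the test file above is importable; the import path is illustrative:

from vllm.model_executor.layers.quantization import (
    QUANTIZATION_METHODS, get_quantization_config)

# Illustrative import path; any config class decorated with
# @register_quantization_config("custom_quant") works here.
from tests.quantization.test_register_quantization_config import (
    CustomQuantConfig)

# The decorator appends the name to QUANTIZATION_METHODS and records the
# class in _CUSTOMIZED_METHOD_TO_QUANT_CONFIG, so both lookups succeed.
assert "custom_quant" in QUANTIZATION_METHODS
assert get_quantization_config("custom_quant") is CustomQuantConfig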
