[Bugfix] Allow fallback to AWQ from AWQMarlin at per-layer granularity #13119

Merged: 3 commits, Feb 12, 2025
35 changes: 19 additions & 16 deletions vllm/model_executor/layers/linear.py
Member Author commented:

The changes in this file just ensure that

self.input_size_per_partition
self.output_size_per_partition
self.output_partition_sizes

are always defined on layers before quant_config.get_quant_method() is called, so we can check them there.
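
As a rough illustration (not code from this PR; the method and class names and the divisibility checks below are made up), a quantization config can now make per-layer decisions inside get_quant_method():

def get_quant_method(self, layer, prefix):
    # Hypothetical sketch: these attributes are guaranteed to be set on
    # LinearBase subclasses before this hook is called.
    if isinstance(layer, LinearBase):
        bad_input = layer.input_size_per_partition % self.group_size != 0
        bad_output = any(n % 64 != 0 for n in layer.output_partition_sizes)
        if bad_input or bad_output:
            # Shapes the optimized kernel cannot handle -> fall back.
            return FallbackLinearMethod(self)   # hypothetical class
        return OptimizedLinearMethod(self)      # hypothetical class
    return None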

@@ -290,29 +290,30 @@ def __init__(self,
                  quant_config: Optional[QuantizationConfig] = None,
                  output_sizes: Optional[list[int]] = None,
                  prefix: str = ""):
-        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config, prefix)
-
-        self.gather_output = gather_output
-
         # Divide the weight matrix along the last dimension.
-        tp_size = get_tensor_model_parallel_world_size()
-        assert self.quant_method is not None
-        self.output_size_per_partition = divide(self.output_size, tp_size)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = input_size
+        self.output_size_per_partition = divide(output_size, self.tp_size)
         self.output_partition_sizes = [self.output_size_per_partition]
         # If QKV or MergedColumn, use output size of each partition.
         if hasattr(self, "output_sizes"):
             self.output_partition_sizes = [
-                divide(output_size, tp_size)
+                divide(output_size, self.tp_size)
                 for output_size in self.output_sizes
             ]
 
+        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
+                         quant_config, prefix)
+
+        self.gather_output = gather_output
+
         if output_sizes is None:
             output_sizes = [output_size]
 
+        assert self.quant_method is not None
         self.quant_method.create_weights(
             layer=self,
-            input_size_per_partition=self.input_size,
+            input_size_per_partition=self.input_size_per_partition,
             output_partition_sizes=self.output_partition_sizes,
             input_size=self.input_size,
             output_size=self.output_size,
@@ -1044,22 +1045,24 @@ def __init__(self,
                  reduce_results: bool = True,
                  quant_config: Optional[QuantizationConfig] = None,
                  prefix: str = ""):
+        # Divide the weight matrix along the first dimension.
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = divide(input_size, self.tp_size)
+        self.output_size_per_partition = output_size
+        self.output_partition_sizes = [output_size]
+
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config, prefix)
 
         self.input_is_parallel = input_is_parallel
         self.reduce_results = reduce_results
 
-        # Divide the weight matrix along the last dimension.
-        self.tp_rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.input_size_per_partition = divide(input_size, self.tp_size)
         assert self.quant_method is not None
-
         self.quant_method.create_weights(
             layer=self,
             input_size_per_partition=self.input_size_per_partition,
-            output_partition_sizes=[self.output_size],
+            output_partition_sizes=self.output_partition_sizes,
             input_size=self.input_size,
             output_size=self.output_size,
             params_dtype=self.params_dtype,
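For context (an illustration, not part of the diff; the sizes are borrowed from a Llama-7B-style down_proj, and group_size divisibility is just one example of the Marlin shape checks): tensor parallelism is exactly why the check has to look at the per-partition sizes computed above, because a shard can violate constraints that the full layer satisfies.

# A layer that is fine at TP=1 may not be once RowParallelLinear splits its
# input dimension across ranks.
input_size = 11008        # full input size of the layer
group_size = 128          # AWQ quantization group size
tp_size = 4

input_size_per_partition = input_size // tp_size     # 2752

print(input_size % group_size)                # 0  -> full layer divides evenly
print(input_size_per_partition % group_size)  # 64 -> the per-rank shard does not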
28 changes: 18 additions & 10 deletions vllm/model_executor/layers/quantization/awq_marlin.py
@@ -13,15 +13,17 @@
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
                                                set_weight_attrs)
-from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
+from vllm.model_executor.layers.quantization.awq import (AWQConfig,
+                                                         is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
-    marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
-    marlin_permute_scales, moe_awq_to_marlin_zero_points,
-    verify_marlin_supported, verify_marlin_supports_shape)
+    check_marlin_supports_layer, marlin_make_empty_g_idx,
+    marlin_make_workspace, marlin_moe_permute_scales, marlin_permute_scales,
+    moe_awq_to_marlin_zero_points, verify_marlin_supported,
+    verify_marlin_supports_shape)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
@@ -40,18 +42,17 @@ class AWQMarlinConfig(QuantizationConfig):
         8: scalar_types.uint8,
     }
 
-    def __init__(self,
-                 weight_bits: int,
-                 group_size: int,
-                 zero_point: bool,
+    def __init__(self, weight_bits: int, group_size: int, zero_point: bool,
                  lm_head_quantized: bool,
-                 modules_to_not_convert: Optional[List[str]] = None) -> None:
+                 modules_to_not_convert: Optional[List[str]],
+                 full_config: Dict[str, Any]) -> None:
         self.pack_factor = 32 // weight_bits  # packed into int32
         self.group_size = group_size
         self.zero_point = zero_point
         self.lm_head_quantized = lm_head_quantized
         self.weight_bits = weight_bits
         self.modules_to_not_convert = modules_to_not_convert or []
+        self.full_config = full_config
 
         if self.weight_bits not in self.TYPE_MAP:
             raise ValueError(f"Unsupported num_bits = {self.weight_bits}. "
@@ -96,7 +97,7 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
         modules_to_not_convert = cls.get_from_keys_or(
             config, ["modules_to_not_convert"], None)
         return cls(weight_bits, group_size, zero_point, lm_head_quantized,
-                   modules_to_not_convert)
+                   modules_to_not_convert, config)
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -124,6 +125,13 @@ def get_quant_method(self, layer: torch.nn.Module,
                 (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
             if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                 return UnquantizedLinearMethod()
+            # Check if the layer is supported by AWQMarlin.
+            if not check_marlin_supports_layer(layer, self.group_size):
+                logger.warning_once(
+                    f"Layer '{prefix}' is not supported by AWQMarlin. "
+                    "Falling back to unoptimized AWQ kernels.")
+                return AWQConfig.from_config(
+                    self.full_config).get_quant_method(layer, prefix)
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
             return AWQMoEMethod(self)
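For illustration (the layer objects and the quantization-config dict below are hypothetical), the practical effect is that one AWQMarlinConfig can now hand out different linear methods per layer instead of rejecting the whole model:

quant_config = AWQMarlinConfig.from_config(hf_quant_dict)  # hf_quant_dict: the checkpoint's quant config

# A layer whose partitioned shape passes check_marlin_supports_layer:
fast = quant_config.get_quant_method(gate_up_proj, prefix="model.layers.0.mlp.gate_up_proj")
# -> AWQMarlinLinearMethod

# A layer that fails the check logs the warning added above and is handed to
# the plain AWQ config instead; every other layer keeps the Marlin kernels.
slow = quant_config.get_quant_method(down_proj, prefix="model.layers.0.mlp.down_proj")
# -> AWQ's unoptimized linear method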
9 changes: 6 additions & 3 deletions vllm/model_executor/layers/quantization/moe_wna16.py
@@ -16,6 +16,8 @@
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import (
     GPTQMarlinConfig)
+from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+    check_marlin_supports_layer)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 
@@ -87,8 +89,8 @@ def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
             modules_to_not_convert = []
         elif linear_quant_method == "awq":
             has_zp = cls.get_from_keys(config, ["zero_point"])
-            modules_to_not_convert = cls.get_from_keys(
-                config, ["modules_to_not_convert"])
+            modules_to_not_convert = cls.get_from_keys_or(
+                config, ["modules_to_not_convert"], None)
         else:
             raise ValueError("moe_wna16 only support gptq and awq.")
 
@@ -135,7 +137,8 @@ def get_quant_method(self, layer: torch.nn.Module,
                     return GPTQConfig.from_config(
                         self.full_config).get_quant_method(layer, prefix)
             elif self.linear_quant_method == "awq":
-                if self.use_marlin:
+                if self.use_marlin and check_marlin_supports_layer(
+                        layer, self.group_size):
                     return AWQMarlinConfig.from_config(
                         self.full_config).get_quant_method(layer, prefix)
                 else:
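A compact paraphrase of what the two moe_wna16 changes mean for AWQ checkpoints (a sketch with assumed argument names; only check_marlin_supports_layer is the real helper):

from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_marlin_supports_layer)

def pick_awq_linear_path(quant_cfg: dict, prefix: str, use_marlin: bool,
                         layer, group_size: int) -> str:
    # "modules_to_not_convert" is now optional in the checkpoint config;
    # a missing key simply means no modules are excluded.
    skipped = quant_cfg.get("modules_to_not_convert") or []
    if any(name in prefix for name in skipped):
        return "unquantized"
    # The Marlin-backed path is taken only when this particular layer also
    # passes the per-layer shape check; otherwise plain AWQ handles it.
    if use_marlin and check_marlin_supports_layer(layer, group_size):
        return "awq_marlin"
    return "awq"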
15 changes: 15 additions & 0 deletions vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -6,6 +6,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.linear import LinearBase
 from vllm.platforms import current_platform
 from vllm.scalar_type import ScalarType, scalar_types
 
@@ -135,6 +136,20 @@ def check_marlin_supports_shape(output_size_per_partition: int,
     return True, None
 
 
+def check_marlin_supports_layer(layer: LinearBase, group_size: int) \
+        -> bool:
+    output_size_per_partition = getattr(layer, "output_size_per_partition",
+                                        None) or layer.output_size
+    input_size_per_partition = getattr(layer, "input_size_per_partition",
+                                       None) or layer.input_size
+
+    return check_marlin_supports_shape(
+        output_size_per_partition=output_size_per_partition,
+        input_size_per_partition=input_size_per_partition,
+        input_size=layer.input_size,
+        group_size=group_size)[0]
+
+
 def marlin_make_workspace(output_size_per_partition: int,
                           device: torch.device) -> torch.Tensor:
     max_workspace_size = (output_size_per_partition //
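One note on the new helper: the getattr(..., None) or ... pattern means it also works for layer objects that never set the per-partition attributes, falling back to the full sizes. A standalone illustration (not vLLM code):

class _ToyLayer:
    # Stand-in for a layer that only defines the unpartitioned sizes.
    input_size = 4096
    output_size = 11008

layer = _ToyLayer()
input_size_per_partition = getattr(layer, "input_size_per_partition",
                                   None) or layer.input_size
output_size_per_partition = getattr(layer, "output_size_per_partition",
                                    None) or layer.output_size
print(input_size_per_partition, output_size_per_partition)  # 4096 11008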