
Commit

[Quant][Perf] Use moe_wna16 kernel by default for MoEs with many experts (#13236)

Signed-off-by: mgoin <[email protected]>
mgoin authored Feb 14, 2025
1 parent c9e2d64 commit 5e5c8e0
Showing 4 changed files with 39 additions and 26 deletions.
2 changes: 1 addition & 1 deletion tests/weight_loading/test_weight_loading.py
@@ -12,7 +12,7 @@
                             "robertgshaw2/zephyr-7b-beta-channelwise-gptq")
 REVISION = os.environ.get("REVISION", "main")
 QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")
-MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "89")
+MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "80")


 @pytest.mark.skipif(
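The only test change lowers the default capability gate from 8.9 (Ada-class GPUs such as L4 or RTX 4090) to 8.0 (Ampere-class GPUs such as A100), so the weight-loading suite can run on more widely available hardware. A minimal sketch of how such an environment-driven gate typically looks; this is illustrative, not the exact body of test_weight_loading.py:

# Illustrative sketch: an environment-driven compute-capability gate in the
# style of MIN_CAPABILITY above (names and body are assumptions, not the
# real test).
import os

import pytest
import torch

MIN_CAPABILITY = os.environ.get("MIN_CAPABILITY", "80")


def _capability() -> int:
    # Encode (major, minor) as a two-digit integer, e.g. (8, 0) -> 80.
    if not torch.cuda.is_available():
        return -1
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor


@pytest.mark.skipif(_capability() < int(MIN_CAPABILITY),
                    reason="GPU is below the required compute capability.")
def test_weight_loading_smoke():
    ...  # load the model selected by MODEL_NAME/QUANTIZATION and run it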
8 changes: 7 additions & 1 deletion vllm/model_executor/layers/quantization/awq_marlin.py
@@ -17,6 +17,7 @@
                                                          is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
+from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
@@ -134,7 +135,12 @@ def get_quant_method(self, layer: torch.nn.Module,
                 self.full_config).get_quant_method(layer, prefix)
             return AWQMarlinLinearMethod(self)
         elif isinstance(layer, FusedMoE):
-            return AWQMoEMethod(self)
+            if layer.num_experts > 32:
+                # For MoEs with many experts the moe_wna16 kernel is faster
+                return MoeWNA16Config.from_config(
+                    self.full_config).get_quant_method(layer, prefix)
+            else:
+                return AWQMoEMethod(self)
         return None

     @classmethod
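Both Marlin configs now share the same dispatch rule: once a FusedMoE layer has more than 32 experts, the quantization method is delegated to a MoeWNA16Config rebuilt from the stored Hugging Face config, on the grounds that the moe_wna16 kernel is faster for many-expert MoEs. A simplified standalone sketch of that rule; the real logic lives inside AWQMarlinConfig.get_quant_method (above) and GPTQMarlinConfig.get_quant_method (below), and the helper here is illustrative, not vLLM API:

def choose_moe_quant_method(quant_config, layer, prefix, marlin_moe_cls):
    """Pick a quant method for a FusedMoE layer (illustrative helper).

    quant_config is assumed to expose full_config, the raw HF quantization
    dict, as AWQMarlinConfig and GPTQMarlinConfig do after this commit.
    """
    if layer.num_experts > 32:
        # For MoEs with many experts (e.g. DeepSeek-style models with
        # hundreds of experts) the moe_wna16 kernel is faster.
        from vllm.model_executor.layers.quantization.moe_wna16 import (
            MoeWNA16Config)
        return MoeWNA16Config.from_config(
            quant_config.full_config).get_quant_method(layer, prefix)
    # Few experts: keep the existing Marlin MoE method
    # (AWQMoEMethod or GPTQMarlinMoEMethod).
    return marlin_moe_cls(quant_config)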
35 changes: 16 additions & 19 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -10,20 +10,18 @@
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
     MPLinearLayerConfig, choose_mp_linear_kernel)
+from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
 from vllm.model_executor.layers.quantization.utils import replace_parameter
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
     get_linear_quant_method)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supported, marlin_moe_permute_scales,
     marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    UnquantizedEmbeddingMethod)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                            GroupQuantScaleParameter,
                                            PackedColumnParameter,
@@ -44,15 +42,10 @@ class GPTQMarlinConfig(QuantizationConfig):
         (8, True): scalar_types.uint8b128,
     }

-    def __init__(
-        self,
-        weight_bits: int,
-        group_size: int,
-        desc_act: bool,
-        is_sym: bool,
-        lm_head_quantized: bool,
-        dynamic: Dict[str, Dict[str, Union[int, bool]]],
-    ) -> None:
+    def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
+                 is_sym: bool, lm_head_quantized: bool,
+                 dynamic: Dict[str, Dict[str, Union[int, bool]]],
+                 full_config: Dict[str, Any]) -> None:
         if desc_act and group_size == -1:
             # In this case, act_order == True is the same as act_order == False
             # (since we have only one group per output channel)
@@ -90,6 +83,7 @@ def __init__(
         self.group_size = group_size
         self.desc_act = desc_act
         self.lm_head_quantized = lm_head_quantized
+        self.full_config = full_config

         if (weight_bits, is_sym) not in self.TYPE_MAP:
             raise ValueError("Unsupported quantization config: "
@@ -132,7 +126,7 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
         return cls(weight_bits, group_size, desc_act, is_sym,
-                   lm_head_quantized, dynamic)
+                   lm_head_quantized, dynamic, config)

     @classmethod
     def override_quantization_method(cls, hf_quant_cfg,
@@ -155,12 +149,15 @@ def override_quantization_method(cls, hf_quant_cfg,
                 " faster inference")
             return None

-    def get_quant_method(
-            self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod",
-                        UnquantizedLinearMethod, UnquantizedEmbeddingMethod]]:
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
         if isinstance(layer, FusedMoE):
-            return GPTQMarlinMoEMethod(self)
+            if layer.num_experts > 32:
+                # For MoEs with many experts the moe_wna16 kernel is faster
+                return MoeWNA16Config.from_config(
+                    self.full_config).get_quant_method(layer, prefix)
+            else:
+                return GPTQMarlinMoEMethod(self)
         return get_linear_quant_method(self, layer, prefix,
                                        GPTQMarlinLinearMethod)

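The constructor change is what makes that delegation possible: GPTQMarlinConfig now keeps the untouched Hugging Face quantization dict as full_config, from_config passes it through, and the return annotation of get_quant_method is loosened to Optional["QuantizeMethodBase"] because the MoE branch may now return a moe_wna16 method. A hypothetical usage sketch, assuming a vLLM build containing this commit; the config values are illustrative, not taken from a real checkpoint:

from vllm.model_executor.layers.quantization.gptq_marlin import (
    GPTQMarlinConfig)

# Made-up GPTQ quantization section as it might appear in a model's
# config.json (illustrative values only).
hf_quant_cfg = {
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "desc_act": False,
    "sym": True,
    "lm_head": False,
}

quant_config = GPTQMarlinConfig.from_config(hf_quant_cfg)

# The raw dict is retained so that, for a FusedMoE layer with more than 32
# experts, get_quant_method() can hand it to MoeWNA16Config.from_config()
# instead of returning GPTQMarlinMoEMethod.
print(quant_config.full_config is hf_quant_cfg)  # expected: True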
20 changes: 15 additions & 5 deletions vllm/model_executor/layers/quantization/moe_wna16.py
@@ -9,13 +9,8 @@
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supports_layer)
 from vllm.model_executor.utils import set_weight_attrs
@@ -37,6 +32,12 @@ def __init__(self, linear_quant_method: str, weight_bits: int,
         self.linear_quant_method = linear_quant_method
         self.full_config = full_config
         self.use_marlin = False
+        # Avoid circular import
+        from vllm.model_executor.layers.quantization.awq import AWQConfig
+        from vllm.model_executor.layers.quantization.awq_marlin import (
+            AWQMarlinConfig)
+        from vllm.model_executor.layers.quantization.gptq_marlin import (
+            GPTQMarlinConfig)
         if self.linear_quant_method == "gptq":
             self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(
                 full_config)
@@ -115,6 +116,8 @@ def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
         capability_tuple = current_platform.get_device_capability()
         device_capability = (-1 if capability_tuple is None else
                              capability_tuple.to_int())
+        # Avoid circular import
+        from vllm.model_executor.layers.quantization.awq import AWQConfig
         awq_min_capability = AWQConfig.get_min_capability()

         gptq_compatible = quant_method == "gptq" and \
@@ -129,6 +132,13 @@ def get_quant_method(self, layer: torch.nn.Module,
         if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
             return UnquantizedLinearMethod()
         elif isinstance(layer, LinearBase):
+            # Avoid circular import
+            from vllm.model_executor.layers.quantization.awq import AWQConfig
+            from vllm.model_executor.layers.quantization.awq_marlin import (
+                AWQMarlinConfig)
+            from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+            from vllm.model_executor.layers.quantization.gptq_marlin import (
+                GPTQMarlinConfig)
             if self.linear_quant_method == "gptq":
                 if self.use_marlin:
                     return GPTQMarlinConfig.from_config(
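Most of the churn in moe_wna16.py is the flip side of the new module-level imports above: gptq_marlin.py and awq_marlin.py now import MoeWNA16Config at import time, so moe_wna16.py can no longer import them at module scope without creating a cycle, and instead defers those imports to call time. A minimal two-file sketch of the pattern; file and function names are illustrative, not vLLM's:

# a.py -- plays the role of gptq_marlin.py / awq_marlin.py
from b import delegate_to_b  # module-level import: a -> b


def method_a(x: int) -> str:
    return f"a({x})"


# b.py -- plays the role of moe_wna16.py
def delegate_to_b(x: int) -> str:
    # Deferred import: b -> a is resolved only when the function runs, by
    # which point a.py has finished initialising, so the cycle never hits a
    # partially-initialised module.
    from a import method_a
    return f"b({method_a(x)})"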
