diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index ca1429f4bee..38f2e7b2d4e 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -12,6 +12,7 @@ requantize_with_max_scale, ) +from sglang.srt.layers.attention import AttentionBackend from sglang.srt.layers.linear import LinearBase, LinearMethodBase from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.quantization.base_config import ( @@ -71,12 +72,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional["QuantizeMethodBase"]: - from vllm.attention.layer import Attention if isinstance(layer, LinearBase): return ModelOptFp8LinearMethod(self) - elif isinstance(layer, Attention): + if isinstance(layer, AttentionBackend): return ModelOptFp8KVCacheMethod(self) + return None def get_scaled_act_names(self) -> List[str]: @@ -182,7 +183,7 @@ def apply( class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): """ - Supports loading kv-cache scaling factors from FP8 checkpoints. + Handles loading FP8 kv-cache scaling factors from modelopt quantized checkpoints. """ def __init__(self, quant_config: ModelOptFp8Config):