From 12f49e15b8f6c9a5c87786405fe4d288d90bacc8 Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Thu, 30 Jan 2025 22:44:41 +0000
Subject: [PATCH] fix format

---
 python/sglang/srt/layers/quantization/modelopt_quant.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index ca1429f4bee..38f2e7b2d4e 100644
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -12,6 +12,7 @@
     requantize_with_max_scale,
 )
 
+from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.layers.linear import LinearBase, LinearMethodBase
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -71,12 +72,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention
 
         if isinstance(layer, LinearBase):
             return ModelOptFp8LinearMethod(self)
-        elif isinstance(layer, Attention):
+        if isinstance(layer, AttentionBackend):
             return ModelOptFp8KVCacheMethod(self)
+        return None
 
     def get_scaled_act_names(self) -> List[str]:
@@ -182,7 +183,7 @@ def apply(
 
 class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
     """
-    Supports loading kv-cache scaling factors from FP8 checkpoints.
+    Handles loading FP8 kv-cache scaling factors from modelopt quantized checkpoints.
     """
 
     def __init__(self, quant_config: ModelOptFp8Config):
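
With this patch, ModelOptFp8Config.get_quant_method dispatches on sglang's own
AttentionBackend instead of importing vllm.attention.layer.Attention, and the
fall-through case returns None explicitly. A minimal sketch of the resulting
dispatch, assuming a layer object is already constructed (the helper name
pick_quant_method and the empty prefix are hypothetical, for illustration only):

    from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config

    def pick_quant_method(quant_config: ModelOptFp8Config, layer):
        # After this patch, dispatch keys off sglang classes only:
        #   LinearBase       -> ModelOptFp8LinearMethod
        #   AttentionBackend -> ModelOptFp8KVCacheMethod
        #   anything else    -> None (made explicit by the added `return None`)
        return quant_config.get_quant_method(layer, prefix="")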