From a8babf82d5feb9275df5ca9722ca4f3345674f0f Mon Sep 17 00:00:00 2001
From: Zhiyu Cheng
Date: Thu, 30 Jan 2025 08:23:46 +0000
Subject: [PATCH] fix format

---
 python/sglang/srt/model_loader/weight_utils.py | 13 ++++++++++++-
 python/sglang/srt/models/llama.py              |  6 ++++++
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py
index c07a346f471..822e28844ab 100644
--- a/python/sglang/srt/model_loader/weight_utils.py
+++ b/python/sglang/srt/model_loader/weight_utils.py
@@ -644,9 +644,20 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
         return remapped_name
 
     possible_scale_names = [".k_scale", ".v_scale"]
+    modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"]
     for scale_name in possible_scale_names:
         if name.endswith(scale_name):
-            remapped_name = name.replace(scale_name, f".attn{scale_name}")
+            # Check and remap the name based on modelopt scale names
+            if any(
+                modelopt_scale_name in name
+                for modelopt_scale_name in modelopt_scale_names
+            ):
+                remapped_name = name.replace(
+                    f".self_attn.{scale_name[1]}_proj{scale_name}",
+                    f".self_attn.attn{scale_name}",
+                )
+            else:
+                remapped_name = name.replace(scale_name, f".attn{scale_name}")
             if remapped_name not in params_dict:
                 print_warning_once(
                     f"Found {scale_name} in the checkpoint (e.g. {name}), "
diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py
index 4ea77eede9b..d23b8597d0d 100644
--- a/python/sglang/srt/models/llama.py
+++ b/python/sglang/srt/models/llama.py
@@ -47,6 +47,7 @@
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
+    maybe_remap_kv_scale_name,
 )
 from sglang.srt.utils import make_layers
 from sglang.utils import get_exception_traceback
@@ -457,6 +458,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 continue
             if name.startswith("model.vision_tower") and name not in params_dict:
                 continue
+            # Handle FP8 kv-scale remapping
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name: