Skip to content

Commit

Permalink
fix format
Browse files: browse the repository at this point in the history
  • Loading branch information
Edwardf0t1 committed Feb 1, 2025
1 parent ca781bd commit a8babf8
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
13 changes: 12 additions & 1 deletion python/sglang/srt/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,9 +644,20 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
return remapped_name

possible_scale_names = [".k_scale", ".v_scale"]
modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"]
for scale_name in possible_scale_names:
if name.endswith(scale_name):
remapped_name = name.replace(scale_name, f".attn{scale_name}")
# Check and remap the name based on modelopt scale names
if any(
modelopt_scale_name in name
for modelopt_scale_name in modelopt_scale_names
):
remapped_name = name.replace(
f".self_attn.{scale_name[1]}_proj{scale_name}",
f".self_attn.attn{scale_name}",
)
else:
remapped_name = name.replace(scale_name, f".attn{scale_name}")
if remapped_name not in params_dict:
print_warning_once(
f"Found {scale_name} in the checkpoint (e.g. {name}), "
Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
kv_cache_scales_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.utils import make_layers
from sglang.utils import get_exception_traceback
Expand Down Expand Up @@ -457,6 +458,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
continue
if name.startswith("model.vision_tower") and name not in params_dict:
continue
# Handle FP8 kv-scale remapping
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue

for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
Expand Down

0 comments on commit a8babf8

Please sign in to comment.