From 4cb7afc9352126d0c1bc2d130302d3a7157f385f Mon Sep 17 00:00:00 2001 From: wejoncy <247153481@qq.com> Date: Wed, 26 Feb 2025 10:08:25 +0800 Subject: [PATCH 1/2] Skip collecting duplicated weight --- src/transformers/modeling_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d7abc3bf7e22..afd13b177747 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3056,6 +3056,9 @@ def save_pretrained( # init state_dict for this shard shard_state_dict = {name: "" for name in shard} for module_name in shard: + # skip to collect this weight again + if shard_state_dict.get(module_name) != '': + continue module = module_map[module_name] # update state dict with onloaded parameters shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict) From 3d5108e825885b92d3489c9bcc22df0fad6fab59 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Wed, 26 Feb 2025 11:59:03 +0800 Subject: [PATCH 2/2] format --- src/transformers/modeling_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index afd13b177747..cd8a84964916 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3057,7 +3057,7 @@ def save_pretrained( shard_state_dict = {name: "" for name in shard} for module_name in shard: # skip to collect this weight again - if shard_state_dict.get(module_name) != '': + if shard_state_dict.get(module_name) != "": continue module = module_map[module_name] # update state dict with onloaded parameters