diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d7abc3bf7e22..cd8a84964916 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3056,6 +3056,9 @@ def save_pretrained( # init state_dict for this shard shard_state_dict = {name: "" for name in shard} for module_name in shard: + # skip to collect this weight again + if shard_state_dict.get(module_name) != "": + continue module = module_map[module_name] # update state dict with onloaded parameters shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)