Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
My own task or dataset (give details below)
Reproduction
When I tried using model.gradient_checkpointing_enable() to reduce memory consumption during training, I encountered the error "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn." After troubleshooting, the cause seems to be that loss.requires_grad ends up False, which prevents backpropagation. The following is a minimal reproduction that directly yields loss.requires_grad == False.
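Roughly, the failing setup looks like this (a minimal sketch only: it assumes the same model path, toy inputs, and LoRA config as in the full script further below, but keeps the stock gradient_checkpointing_enable()):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, TaskType, get_peft_model

model_name = "/workspace/model/CodeLlama-13b-Instruct-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

model.gradient_checkpointing_enable()  # stock transformers implementation
model.config.use_cache = False

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "up_proj", "k_proj", "down_proj"],
    lora_dropout=0.05,
    task_type=TaskType.CAUSAL_LM,
)
model = get_peft_model(model, lora_config)
model.train()

batch = tokenizer("input test output test", return_tensors="pt").to("cuda")
output = model(**batch, labels=batch["input_ids"])
print(f"loss: {output['loss'].requires_grad}")  # prints: loss: False (the problem described above)
```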
This is confusing because model.gradient_checkpointing_enable() is designed to reduce memory consumption, but if it leaves loss.requires_grad set to False, it breaks the normal training process. Meanwhile, when I use similar code from LLaMA-Factory to achieve the same effect as model.gradient_checkpointing_enable(), I find that loss.requires_grad is True. Below is that code:
```python
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "4"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
import copy
from types import MethodType
from functools import partial
import inspect
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from transformers import PreTrainedModel


def _gradient_checkpointing_enable(
    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
    r"""Activates gradient checkpointing for the current model.

    Modification of the original method to enable gradient checkpointing for block-wise optimizer.
    """
    from torch.utils.checkpoint import checkpoint

    if not self.supports_gradient_checkpointing:
        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {"use_reentrant": True}

    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)

    def custom_gradient_checkpointing_func(func, *args, **kwargs):
        module: "torch.nn.Module" = func.__self__
        if any(param.requires_grad for param in module.parameters()):
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)

        return gradient_checkpointing_func(func, *args, **kwargs)

    if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
        self.apply(partial(self._set_gradient_checkpointing, value=True))
        self.enable_input_require_grads()
        print("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
    else:  # have already enabled input require gradients
        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)
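
# Note on the wrapper above: custom_gradient_checkpointing_func calls
# requires_grad_(True) on every floating-point tensor argument passed to a
# checkpointed module that has at least one trainable parameter. With
# use_reentrant=True, torch.utils.checkpoint only builds an autograd graph when
# some input requires grad, so this appears to be what keeps loss.requires_grad
# True in main() below.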

def main():
    train_data = {"input": "input test", "output": "output test"}
    model_name = "/workspace/model/CodeLlama-13b-Instruct-hf"
    output_dir = "./test_debug"
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # set the pad token of the model's configuration
    model.config.pad_token_id = model.config.eos_token_id

    # return
    if not getattr(model, "supports_gradient_checkpointing", False):
        print("Current model does not support gradient checkpointing.")
    else:
        # use_reentrant=False might increase VRAM usage (has not been empirically verified yet)
        # According to: https://github.com/huggingface/transformers/issues/28339
        model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
        setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
        print("Gradient checkpointing enabled.")

    input_ids = tokenizer.encode(train_data["input"])
    output_ids = tokenizer.encode(train_data["output"])
    model_inputs_output = input_ids + output_ids + [tokenizer.eos_token_id]
    model_inputs_output = torch.tensor(model_inputs_output, dtype=torch.int64)
    labels = copy.deepcopy(model_inputs_output)
    labels[: len(input_ids)] = -1
    example_mask = model_inputs_output.ge(0)
    label_mask = labels.ge(0)
    model_inputs_output[~example_mask] = 0
    labels[~label_mask] = -100
    train_dataset = {
        "input_ids": model_inputs_output.unsqueeze(0).to("cuda"),
        "attention_mask": example_mask.unsqueeze(0).to("cuda"),
        "labels": labels.unsqueeze(0).to("cuda"),
    }

    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "up_proj", "k_proj", "down_proj"],  # consistent with llama-factory
        lora_dropout=0.05,
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, lora_config)
    # model.gradient_checkpointing_enable()
    model.train()
    model.print_trainable_parameters()
    model.to("cuda")

    output = model(**train_dataset)
    loss = output["loss"]
    print(f"loss: {loss.requires_grad}")


if __name__ == "__main__":
    main()
```
The output is:

```
loss: True
```
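If I read the custom function correctly, the key difference is that it forces the floating-point inputs of each checkpointed block to require grad. The stock path would presumably behave the same once model.enable_input_require_grads() is called on the base model; here is a sketch of that variant (an untested assumption on my side, reusing the same forward pass as above):

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "/workspace/model/CodeLlama-13b-Instruct-hf", torch_dtype=torch.float16, device_map="auto"
)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
model.enable_input_require_grads()  # hooks the input embeddings so their output requires grad
model.config.use_cache = False

model = get_peft_model(
    model,
    LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, task_type=TaskType.CAUSAL_LM),
)
model.train()
# ... same tokenization and forward pass as in the script above;
# the expectation is loss.requires_grad == True here as well.
```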
Expected behavior
I am not entirely sure if this is a bug in the implementation of model.gradient_checkpointing_enable(). If it is not, please feel free to close the issue directly and let me know. Thank you for taking the time to look into this issue :)
System Info
Python 3.9.19
transformers 4.42.0
torch 2.2.2+cu118
peft 0.12.0
Who can help?
No response