model.gradient_checkpointing_enable() makes loss.requires_grad be False #35826

Open · ZCWei51 opened this issue Jan 22, 2025 · 1 comment · Labels: bug

ZCWei51 commented Jan 22, 2025

System Info

Python 3.9.19
transformers 4.42.0
torch 2.2.2+cu118
peft 0.12.0

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

When I tried using model.gradient_checkpointing_enable() to reduce memory consumption during training, I ran into the error "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn". After troubleshooting, the issue seems to be that loss.requires_grad is False, which prevents backpropagation. The following script reproduces loss.requires_grad == False directly:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType

def main():
    train_data = {"input": "input test", "output": "output test"}
    model_name = "/workspace/model/CodeLlama-13b-Instruct-hf"
    output_dir = "./test_debug"
    
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

    input_ids = tokenizer.encode(train_data["input"])
    output_ids = tokenizer.encode(train_data["output"])
    model_inputs_output = input_ids + output_ids + [tokenizer.eos_token_id]
    model_inputs_output = torch.tensor(model_inputs_output, dtype=torch.int64)
    labels = copy.deepcopy(model_inputs_output)
    labels[: len(input_ids)] = -1  # mask the prompt tokens (set to -100 below)
    example_mask = model_inputs_output.ge(0)
    label_mask = labels.ge(0)
    model_inputs_output[~example_mask] = 0
    labels[~label_mask] = -100
    train_dataset = {
            "input_ids": model_inputs_output.unsqueeze(0).to("cuda"),
            "attention_mask": example_mask.unsqueeze(0).to("cuda"),
            "labels": labels.unsqueeze(0).to("cuda")
        }

    lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "up_proj", "k_proj", "down_proj"],  # same target modules as LLaMA-Factory
            lora_dropout=0.05,
            task_type=TaskType.CAUSAL_LM
        )
    model = get_peft_model(model, lora_config)
    model.gradient_checkpointing_enable()
    model.train()    
    model.print_trainable_parameters()
    model.to("cuda")

    output = model(**train_dataset)
    loss = output["loss"]
    print(f"loss: {loss.requires_grad}")


if __name__ == "__main__":
    main()

The output is:

loss: False
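
For context, transformers defaults to reentrant activation checkpointing here, and reentrant checkpointing skips the backward pass through a checkpointed block whose tensor inputs do not require grad; with the base model frozen under LoRA, the embedding output does not require grad, so the loss comes out detached. As a minimal, unverified sketch, one commonly suggested workaround is to request non-reentrant checkpointing instead (gradient_checkpointing_kwargs is an existing argument of gradient_checkpointing_enable in recent transformers versions):

# sketch only: replace the plain gradient_checkpointing_enable() call in the script above
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})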

This is confusing because model.gradient_checkpointing_enable() is only meant to reduce memory consumption, yet with loss.requires_grad set to False it breaks the normal training process. Meanwhile, when I use similar code from LLaMA-Factory to achieve the same effect as model.gradient_checkpointing_enable(), loss.requires_grad is True. Below is the code:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
import copy
import inspect
from functools import partial
from types import MethodType
from typing import Any, Dict, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel
from peft import get_peft_model, LoraConfig, TaskType

def _gradient_checkpointing_enable(
    self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None
) -> None:
    r"""
    Activates gradient checkpointing for the current model.

    Modification of the original method to enable gradient checkpointing for block-wise optimizer.
    """
    from torch.utils.checkpoint import checkpoint

    if not self.supports_gradient_checkpointing:
        raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__))

    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {"use_reentrant": True}

    gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs)

    def custom_gradient_checkpointing_func(func, *args, **kwargs):
        module: "torch.nn.Module" = func.__self__

        if any(param.requires_grad for param in module.parameters()):
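            # With reentrant checkpointing, backward is skipped for a block whose
            # tensor inputs do not require grad, so force the floating-point inputs
            # to require grad whenever this module holds trainable (e.g. LoRA) params.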
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)

        return gradient_checkpointing_func(func, *args, **kwargs)

    if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
        self.apply(partial(self._set_gradient_checkpointing, value=True))
        self.enable_input_require_grads()
        print("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
    else:  # have already enabled input require gradients
        self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func)


def main():
    train_data = {"input": "input test", "output": "output test"}
    model_name = "/workspace/model/CodeLlama-13b-Instruct-hf"
    output_dir = "./test_debug"
    
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    # set the pad token on the model's configuration
    model.config.pad_token_id = model.config.eos_token_id
    if not getattr(model, "supports_gradient_checkpointing", False):
        print("Current model does not support gradient checkpointing.")
    else:
        # use_reentrant=False might increase VRAM usage (this has not been verified empirically yet)
        # According to: https://github.com/huggingface/transformers/issues/28339
        model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model)
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
        setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
        print("Gradient checkpointing enabled.")

    input_ids = tokenizer.encode(train_data["input"])
    output_ids = tokenizer.encode(train_data["output"])
    model_inputs_output = input_ids + output_ids + [tokenizer.eos_token_id]
    model_inputs_output = torch.tensor(model_inputs_output, dtype=torch.int64)
    labels = copy.deepcopy(model_inputs_output)
    labels[: len(input_ids)] = -1  # mask the prompt tokens (set to -100 below)
    example_mask = model_inputs_output.ge(0)
    label_mask = labels.ge(0)
    model_inputs_output[~example_mask] = 0
    labels[~label_mask] = -100
    train_dataset = {
            "input_ids": model_inputs_output.unsqueeze(0).to("cuda"),
            "attention_mask": example_mask.unsqueeze(0).to("cuda"),
            "labels": labels.unsqueeze(0).to("cuda")
        }

    lora_config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "up_proj", "k_proj", "down_proj"],  # same target modules as LLaMA-Factory
            lora_dropout=0.05,
            task_type=TaskType.CAUSAL_LM
        )
    model = get_peft_model(model, lora_config)
    # model.gradient_checkpointing_enable()
    model.train()    
    model.print_trainable_parameters()
    model.to("cuda")

    output = model(**train_dataset)
    loss = output["loss"]
    print(f"loss: {loss.requires_grad}")


if __name__ == "__main__":
    main()

The output is:

loss: True
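
For comparison, the part of the patched function that changes the outcome is the arg.requires_grad_(True) call on the checkpointed inputs. transformers exposes a related helper, enable_input_require_grads(), which registers a forward hook on the input embeddings so that their output requires grad; a minimal, unverified sketch of using it in the first script instead of this patch:

model.enable_input_require_grads()    # make the embedding output require grad so gradients can reach the LoRA layers
model.gradient_checkpointing_enable()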

Expected behavior

I am not entirely sure whether this is a bug in the implementation of model.gradient_checkpointing_enable(). If it is not, please feel free to close the issue and let me know. Thank you for taking the time to look into this :)

@ZCWei51 ZCWei51 added the bug label Jan 22, 2025
Rocketknight1 (Member) commented:

cc @muellerzr @SunMarc
