Fix lost loss values when using user-defined compute_loss_func in some cases #35872

Open · wants to merge 1 commit into main

Conversation

dolphin-Dang

What does this PR do?

When using a user-defined compute_loss_func, the Trainer.compute_loss() function may execute 'labels = inputs.pop("labels")', which removes 'labels' from the inputs that are then passed to the model:

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        if (self.label_smoother is not None or self.compute_loss_func is not None) and "labels" in inputs:
            labels = inputs.pop("labels")
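
A quick illustration of the side effect, using a plain dict with made-up values as a stand-in for the batch:

inputs = {"input_ids": [[101, 102]], "attention_mask": [[1, 1]], "labels": [[0, -100]]}
labels = inputs.pop("labels")
# model(**inputs) is now called without labels, so forward() skips its
# internal loss computation and returns logits only.
print("labels" in inputs)  # False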

Some predefined models, such as BertForTokenClassification, will then not return a loss value from their forward() method, since they only compute one when labels are provided:

class BertForTokenClassification(BertPreTrainedModel):
    ...
    def forward(
        self,
        ...
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        ...
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output
        ...

Then, in the Trainer.prediction_step() method, the slice 'logits = outputs[1:]' that is meant to strip the loss from the outputs actually drops the logits and leaves an empty tuple, which is a bug:

    def prediction_step(
        self,
        ...
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        ...

        with torch.no_grad():
            if is_sagemaker_mp_enabled():
                ...
            else:
                if has_labels or loss_without_labels:
                    ...

                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        logits = outputs[1:]
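
To see why that slice misbehaves, here is a minimal sketch with dummy tensors standing in for the two shapes of outputs that return_dict=False can produce (the shapes are made up for the example):

import torch

loss = torch.tensor(0.42)
logits = torch.randn(2, 5, 25)  # (batch, seq_len, num_labels)

outputs_with_loss = (loss, logits)   # "labels" kept in inputs -> the loss comes first
outputs_logits_only = (logits,)      # "labels" popped by compute_loss -> logits only

print(len(outputs_with_loss[1:]))    # 1: slicing off the loss keeps the logits
print(len(outputs_logits_only[1:]))  # 0: the logits are silently dropped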

The correct way is as follows:

                    if isinstance(outputs, dict):
                        logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
                    else:
                        if len(outputs) == 1:
                            logits = outputs
                        else:
                            logits = outputs[1:]
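
Continuing the sketch above, the guard only strips the first element when a loss is actually in front of the logits:

import torch

def slice_logits(outputs):
    # Mirrors the proposed fix: a 1-tuple can only be (logits,), so keep it whole.
    if len(outputs) == 1:
        return outputs
    return outputs[1:]

loss = torch.tensor(0.42)
logits = torch.randn(2, 5, 25)

print(len(slice_logits((loss, logits))))  # 1: the loss is stripped, the logits survive
print(len(slice_logits((logits,))))       # 1: no longer an empty tuple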

Who can review?

@muellerzr and @SunMarc

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@SunMarc
Member

SunMarc commented Jan 24, 2025

Thanks for the explanation ! Could you share a reproducer of your issue please ?

@dolphin-Dang
Author

Thanks for the explanation ! Could you share a reproducer of your issue please ?

train.py:

import os
import argparse
import random
import torch
from torch import nn
from transformers import TrainingArguments, Trainer, DataCollatorForTokenClassification, AutoTokenizer
from sklearn.model_selection import train_test_split
from src.dataset import NERDataset
from src.model import BertForNER
from functools import partial

def parse_args():
    parser = argparse.ArgumentParser(description="Train a BERT-based NER model")
    parser.add_argument("--train_data", type=str, required=True, help="Path to the training dataset file")
    parser.add_argument("--val_ratio", type=float, default=0.1, help="Ratio of the training data to use for validation")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the pre-trained BERT model")
    parser.add_argument("--output_dir", type=str, default="./output", help="Output directory for the trained model")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for training and evaluation")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate for AdamW optimizer")
    parser.add_argument("--warmup_ratio", type=float, default=0.1, help="Ratio of warmup steps for the scheduler")
    parser.add_argument("--logging_steps", type=int, default=10, help="Number of steps for logging")
    parser.add_argument("--save_steps", type=int, default=500, help="Number of steps for saving the model")
    parser.add_argument("--random_seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--label_weight", type=float, default=1.0, help="Label weight for label 'O'.")
    parser.add_argument("--f1", type=str, default="macro", help="Macro or micro F1 score.")
    parser.add_argument("--max_len", type=int, default=200, help="Maximum characters in one single sentence.")
    parser.add_argument("--mode", type=str, choices=["A", "B", "C"], help="Mode for the NER task (A, B, or C). If not specified, it will be inferred from the train_data filename.")
    parser.add_argument("--crf", type=bool, default=False, help="Use CRF decoder or not.")
    return parser.parse_args()

def get_num_labels(mode):
    if mode in ["A", "C"]:
        return 25
    elif mode == "B":
        return 13
    else:
        raise ValueError(f"Invalid mode: {mode}. Mode should be one of A, B, or C.")

def main():
    args = parse_args()
    if args.mode is None:
        mode = os.path.basename(args.train_data).split('_')[-1].split('.')[0]
        print(f"Mode inferred from train_data filename: {mode}")
    else:
        mode = args.mode
        print(f"Mode specified: {mode}")

    num_labels = get_num_labels(mode)
    print(f"Number of labels for mode {mode}: {num_labels}")

    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = BertForNER(args.model_path, num_labels=num_labels, use_crf=args.crf)

    dataset = NERDataset(args.train_data, tokenizer, max_len=args.max_len)
    train_dataset, val_dataset = train_test_split(dataset, test_size=args.val_ratio, random_state=args.random_seed)
    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.num_epochs,
        save_total_limit=2,
        logging_steps=args.logging_steps,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        warmup_ratio=args.warmup_ratio,
    )

    def compute_metrics(eval_pred):
        from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
        import numpy as np
        logits, labels = eval_pred
        predictions = logits.argmax(axis=-1)

        valid_indices = labels != -100
        labels = labels[valid_indices]
        predictions = predictions[valid_indices]

        # F1 Recall Precision
        f1 = f1_score(labels, predictions, average=args.f1)
        recall = recall_score(labels, predictions, average=args.f1)
        precision = precision_score(labels, predictions, average=args.f1)

        return {"f1": f1, "recall": recall, "precision": precision}


    def custom_loss_func(outputs, labels, label_weights=None, **kwargs):
        logits = outputs[0]
        if label_weights is not None:
            label_weights = label_weights.to(logits.device)
        labels = labels.to(logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=label_weights)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss

    label_weights = torch.tensor([args.label_weight if i == 0 else 1.0 for i in range(num_labels)], dtype=torch.float32)
    compute_loss = partial(custom_loss_func, label_weights=label_weights)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        compute_loss_func=compute_loss,
    )

    trainer.train()

    trainer.save_model(os.path.join(args.output_dir, "finetuned"))
    print(f"Model saved to {os.path.join(args.output_dir, 'finetuned')}")

if __name__ == "__main__":
    main()

src/model.py:

from transformers import BertForTokenClassification
import torch
from torch import nn

class BertForNER(nn.Module):
    def __init__(self, model_path, num_labels, use_crf=False):
        super(BertForNER, self).__init__()
        self.num_labels = num_labels
        self.bert = BertForTokenClassification.from_pretrained(model_path, num_labels=num_labels)


    def forward(self, input_ids, attention_mask, labels=None):
        bert_outputs = self.bert(
            input_ids=input_ids, labels=labels, attention_mask=attention_mask, return_dict=False
        )
        return bert_outputs

When using a user-defined compute_loss_func, the EvalPrediction passed to compute_metrics does not contain a valid predictions field. This is caused by the issue described above: in Trainer.prediction_step(), the slice 'logits = outputs[1:]' that is meant to strip the loss from the outputs drops the logits instead and leaves an empty tuple.
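
Concretely, the symptom is that evaluation aggregates an empty predictions array, so the compute_metrics above has nothing to argmax over. A rough stand-in for that failure (not a trace of the Trainer's actual aggregation path):

import numpy as np

empty_predictions = np.array([])         # stand-in for EvalPrediction.predictions
try:
    empty_predictions.argmax(axis=-1)    # first thing compute_metrics above does
except ValueError as err:
    print(f"compute_metrics cannot run: {err}")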
