Fix lost loss values when using user-defined compute_loss_func in some cases #35872
base: main
Conversation
Fix lost loss values when using user-defined compute_loss_func in some cases

When using a user-defined `compute_loss_func`, `Trainer.compute_loss()` may execute `labels = inputs.pop("labels")`, so the model never receives `labels` among its inputs. Some predefined models, such as `BertForTokenClassification`, then return no loss value from their `forward()` methods. Later, in `Trainer.prediction_step()`, the slice `logits = outputs[1:]`, intended to strip the loss from the other outputs, actually drops the logits and leaves an empty tuple, which is a bug.
Thanks for the explanation! Could you share a reproducer of your issue, please?
train.py:

```python
import os
import argparse
import random
import torch
from torch import nn
from transformers import TrainingArguments, Trainer, DataCollatorForTokenClassification, AutoTokenizer
from sklearn.model_selection import train_test_split
from src.dataset import NERDataset
from src.model import BertForNER
from functools import partial


def parse_args():
    parser = argparse.ArgumentParser(description="Train a BERT-based NER model")
    parser.add_argument("--train_data", type=str, required=True, help="Path to the training dataset file")
    parser.add_argument("--val_ratio", type=float, default=0.1, help="Ratio of the training data to use for validation")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the pre-trained BERT model")
    parser.add_argument("--output_dir", type=str, default="./output", help="Output directory for the trained model")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for training and evaluation")
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate for AdamW optimizer")
    parser.add_argument("--warmup_ratio", type=float, default=0.1, help="Ratio of warmup steps for the scheduler")
    parser.add_argument("--logging_steps", type=int, default=10, help="Number of steps for logging")
    parser.add_argument("--save_steps", type=int, default=500, help="Number of steps for saving the model")
    parser.add_argument("--random_seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--label_weight", type=float, default=1.0, help="Label weight for label 'O'.")
    parser.add_argument("--f1", type=str, default="macro", help="Macro or micro F1 score.")
    parser.add_argument("--max_len", type=int, default=200, help="Maximum characters in one single sentence.")
    parser.add_argument("--mode", type=str, choices=["A", "B", "C"], help="Mode for the NER task (A, B, or C). If not specified, it will be inferred from the train_data filename.")
    parser.add_argument("--crf", type=bool, default=False, help="Use CRF decoder or not.")
    return parser.parse_args()


def get_num_labels(mode):
    if mode in ["A", "C"]:
        return 25
    elif mode == "B":
        return 13
    else:
        raise ValueError(f"Invalid mode: {mode}. Mode should be one of A, B, or C.")


def main():
    args = parse_args()
    if args.mode is None:
        mode = os.path.basename(args.train_data).split('_')[-1].split('.')[0]
        print(f"Mode inferred from train_data filename: {mode}")
    else:
        mode = args.mode
        print(f"Mode specified: {mode}")
    num_labels = get_num_labels(mode)
    print(f"Number of labels for mode {mode}: {num_labels}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    model = BertForNER(args.model_path, num_labels=num_labels, use_crf=args.crf)
    dataset = NERDataset(args.train_data, tokenizer, max_len=args.max_len)
    train_dataset, val_dataset = train_test_split(dataset, test_size=args.val_ratio, random_state=args.random_seed)
    data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.num_epochs,
        save_total_limit=2,
        logging_steps=args.logging_steps,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        warmup_ratio=args.warmup_ratio,
    )

    def compute_metrics(eval_pred):
        from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
        import numpy as np
        logits, labels = eval_pred
        predictions = logits.argmax(axis=-1)
        valid_indices = labels != -100
        labels = labels[valid_indices]
        predictions = predictions[valid_indices]
        # F1 / recall / precision on the valid (non -100) positions
        f1 = f1_score(labels, predictions, average=args.f1)
        recall = recall_score(labels, predictions, average=args.f1)
        precision = precision_score(labels, predictions, average=args.f1)
        return {"f1": f1, "recall": recall, "precision": precision}

    def custom_loss_func(outputs, labels, label_weights=None, **kwargs):
        logits = outputs[0]
        if label_weights is not None:
            label_weights = label_weights.to(logits.device)
        labels = labels.to(logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=label_weights)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss

    label_weights = torch.tensor([args.label_weight if i == 0 else 1.0 for i in range(num_labels)], dtype=torch.float32)
    compute_loss = partial(custom_loss_func, label_weights=label_weights)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        compute_loss_func=compute_loss,
    )
    trainer.train()
    trainer.save_model(os.path.join(args.output_dir, "finetuned"))
    print(f"Model saved to {os.path.join(args.output_dir, 'finetuned')}")


if __name__ == "__main__":
    main()
```

src/model.py:

```python
from transformers import BertForTokenClassification
import torch
from torch import nn


class BertForNER(nn.Module):
    def __init__(self, model_path, num_labels, use_crf=False):
        super(BertForNER, self).__init__()
        self.num_labels = num_labels
        self.bert = BertForTokenClassification.from_pretrained(model_path, num_labels=num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        bert_outputs = self.bert(
            input_ids=input_ids, labels=labels, attention_mask=attention_mask, return_dict=False
        )
        return bert_outputs
```

When using a user-defined `compute_loss_func`, the `EvalPrediction` passed to `compute_metrics` contains no valid `predictions`, for the reason described above: in `Trainer.prediction_step()`, the slice `logits = outputs[1:]` meant to strip the loss from the other outputs actually drops the logits and leaves an empty tuple.
What does this PR do?
When using a user-defined `compute_loss_func`, `Trainer.compute_loss()` may execute `labels = inputs.pop("labels")`, so the model never receives `labels` among its inputs.
Some predefined models, like `BertForTokenClassification`, then return no loss value from their `forward()` methods.
As a result, in `Trainer.prediction_step()`, the slice `logits = outputs[1:]`, intended to strip the loss from the other outputs, actually drops the logits and leaves an empty tuple, which is a bug.
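To make the failure concrete, here is a minimal, self-contained illustration of the slicing step (the tensor shape is made up; only the tuple structure matters):

```python
import torch

# When `labels` were popped from the model inputs, a tuple-returning model
# yields only the logits, with no leading loss element.
outputs = (torch.randn(2, 5, 25),)  # (logits,)

# prediction_step() assumes outputs[0] is the loss and strips it:
logits = outputs[1:]
print(logits)  # () -- the logits were dropped, not the loss
```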
The correct way is as follows:
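As a sketch of the idea (a hypothetical helper for illustration, not the actual patch in this PR): only strip the leading element when the model actually returned a loss.

```python
import torch

def split_outputs(outputs, model_returned_loss):
    # Hypothetical helper: strip the loss from a tuple of model outputs
    # only when the model actually returned one.
    if model_returned_loss:
        return outputs[0], outputs[1:]
    return None, outputs

# Labels were passed to the model: outputs = (loss, logits).
loss, logits = split_outputs((torch.tensor(0.3), torch.randn(2, 5, 25)), True)
assert len(logits) == 1  # logits survive the slice

# Labels were popped for compute_loss_func: outputs = (logits,).
loss, logits = split_outputs((torch.randn(2, 5, 25),), False)
assert len(logits) == 1  # logits are kept instead of being sliced away
```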
Who can review?
@muellerzr and @SunMarc
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.