Handle num_items_in_batch in Mistral's forward #34576

Open · wants to merge 1 commit into main from dev/mistral-num-items-in-batch
Conversation

gheinrich

What does this PR do?

This PR enables handling of loss keyword arguments in the Mistral model's `forward()` method. Specifically, if `num_items_in_batch` is passed, its value is used to properly normalize the loss.

This relates to the Gradient Accumulation fix (#34191).

Fixes #34575

cc @ArthurZucker as it relates to text models.
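
For background, here is a minimal sketch (illustrative only, not this PR's code) of why `num_items_in_batch` matters: with gradient accumulation, a loss that is mean-reduced per micro-batch over-weights micro-batches with few label tokens, whereas summing the per-token losses and dividing by the label-token count of the whole accumulated batch reproduces the single-large-batch result. The function name `causal_lm_loss` below is hypothetical.

```python
import torch.nn.functional as F

def causal_lm_loss(logits, labels, num_items_in_batch=None, ignore_index=-100):
    # Shift so that tokens < n predict token n (standard causal LM setup).
    shift_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))
    shift_labels = labels[..., 1:].contiguous().view(-1)

    if num_items_in_batch is None:
        # Default: average over this micro-batch's tokens only.
        return F.cross_entropy(shift_logits, shift_labels, ignore_index=ignore_index)

    # Gradient-accumulation-aware path: sum the per-token losses, then divide
    # by the label-token count of the whole accumulated batch.
    loss = F.cross_entropy(
        shift_logits, shift_labels, ignore_index=ignore_index, reduction="sum"
    )
    return loss / num_items_in_batch
```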

gheinrich force-pushed the dev/mistral-num-items-in-batch branch from adf418a to a4faa09 on November 2, 2024.
@Rocketknight1 (Member)

cc @muellerzr for the GA fix as well!

@ArthurZucker (Collaborator) left a comment


Hey! This is not how we fixed it for other models 😉 See:

```python
loss = None
if labels is not None:
    loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
```
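
As a quick check of what that pattern buys (reusing the hypothetical `causal_lm_loss` sketch above, not the library's actual implementation): two micro-batches of different lengths, each normalized by the shared total, accumulate to the same loss a single combined batch would give.

```python
import torch

torch.manual_seed(0)
vocab = 10
# Two micro-batches with different sequence lengths.
logits_a, labels_a = torch.randn(1, 4, vocab), torch.randint(vocab, (1, 4))
logits_b, labels_b = torch.randn(1, 7, vocab), torch.randint(vocab, (1, 7))

# Total number of (shifted) label tokens across the accumulation window.
n = (labels_a.shape[1] - 1) + (labels_b.shape[1] - 1)  # 3 + 6 = 9

# Each call sums its token losses and divides by the shared total, so the
# accumulated sum equals the mean loss over all 9 tokens at once.
total = (causal_lm_loss(logits_a, labels_a, num_items_in_batch=n)
         + causal_lm_loss(logits_b, labels_b, num_items_in_batch=n))
```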

@gheinrich (Author)

Hello, other models have the loss function defined in a parent class, whereas Mistral models define it inside the `forward()` method. If I don't want to change this behavior, how do you suggest I proceed?

@ArthurZucker (Collaborator) commented Nov 25, 2024

You pretty much need to copy-paste the code from Llama / the other models 🤗 The parent class is the same for Mistral as well.
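
To make that suggestion concrete, here is a toy module (hypothetical, not Mistral's actual code) following the Llama-style pattern: `forward()` does no loss arithmetic itself and simply forwards `**kwargs`, including `num_items_in_batch`, to a shared loss function. In `transformers` the parent class supplies `self.loss_function`, as noted above.

```python
import torch
from torch import nn

class TinyCausalLM(nn.Module):
    """Toy stand-in for the delegation pattern: forward() hands **kwargs
    (including num_items_in_batch) to a pluggable loss function."""

    def __init__(self, vocab_size=10, hidden=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size)
        self.loss_function = causal_lm_loss  # e.g. the sketch shown earlier

    def forward(self, input_ids, labels=None, **kwargs):
        logits = self.lm_head(self.embed(input_ids))
        loss = None
        if labels is not None:
            # kwargs may carry num_items_in_batch; normalization is decided
            # inside the loss function, so forward() stays agnostic.
            loss = self.loss_function(logits=logits, labels=labels, **kwargs)
        return loss, logits
```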

@muellerzr (Contributor)

Will be superseded/fulfilled by #35875.

Successfully merging this pull request may close: Unhandled 'num_items_in_batch' in Mistral model (#34575)