
[Transformers future] Loss Computation for Compatibility with Transformers 4.48.3 #1794

Open · wants to merge 4 commits into base: transformers_future
Conversation

@yafshar (Contributor) commented Feb 23, 2025

What does this PR do?

  • Update the GaudiTrainer to better align with Transformers 4.48.3, with minor enhancements from 4.49.0

    • Remove get_batch_samples_transformers and use the superclass method get_batch_samples. num_items_in_batch is only constant in the special case of running the examples with --dataset_concatenation.
  • Update the Llama model loss computation (see the sketch after this list)

    • Simplified the loss computation by using a configurable loss function.
    • Replaced the inline loss computation logic with a call to self.loss_function.
    • Introduced ForCausalLMContextParallelLoss for context parallel loss computation.
    • Updated the __init__ method to set the loss function based on the parallel strategy.
    • Ensured compatibility with Transformers 4.48.3 by aligning with its structure and conventions.
  • Refactor the context parallel loss computation to better match upstream Transformers

    • Replaced _ContextParallelLoss with ContextParallelLossFunction for better clarity and consistency.
    • Updated fixed_cross_entropy to use ContextParallelLossFunction for gathering losses across context parallel groups.
    • Introduced ForCausalLMContextParallelLoss to handle loss computation for causal language modeling with context parallelism.
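To illustrate the intended structure, below is a minimal Python sketch of the fixed_cross_entropy / ForCausalLMContextParallelLoss pair described above. This is not the actual optimum-habana implementation: the signatures, the cp_group argument, and the plain all_reduce (standing in for ContextParallelLossFunction) are assumptions made for illustration only.

    # Hedged sketch only: the real code lives in optimum-habana's Llama modeling / loss utilities.
    import torch.distributed as dist
    import torch.nn.functional as F

    def fixed_cross_entropy(source, target, num_items_in_batch=None, ignore_index=-100, cp_group=None):
        # Sum of token losses on this rank; with context parallelism each rank
        # only sees its shard of the sequence.
        loss = F.cross_entropy(source, target, ignore_index=ignore_index, reduction="sum")
        if num_items_in_batch is None:
            num_items_in_batch = (target != ignore_index).sum()
            if cp_group is not None and dist.is_initialized():
                dist.all_reduce(num_items_in_batch, op=dist.ReduceOp.SUM, group=cp_group)
        if cp_group is not None and dist.is_initialized():
            # Gather (sum) the partial losses from all ranks of the context-parallel group.
            dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=cp_group)
        return loss / num_items_in_batch

    def ForCausalLMContextParallelLoss(logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100, cp_group=None, **kwargs):
        # Shift so that tokens < n predict token n, as in the upstream ForCausalLMLoss;
        # the real implementation also has to account for how labels are sharded across ranks.
        logits = logits.float()
        shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
        return fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, cp_group)

In the model's __init__, the idea is then to pick the loss based on the parallel strategy (e.g. self.loss_function = ForCausalLMContextParallelLoss when context parallelism is enabled, the default causal-LM loss otherwise) and to call self.loss_function(...) in forward() instead of computing the loss inline.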

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

@yafshar yafshar marked this pull request as ready for review February 23, 2025 12:58
@yafshar (Contributor, Author) commented Feb 23, 2025

@regisss I removed get_batch_samples_transformers in GaudiTrainer and used the superclass method get_batch_samples. num_items_in_batch is only constant in the special case of running the examples with --dataset_concatenation. That could be an extra optimization later, but it is not the general case.
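For context, here is a simplified paraphrase (not the verbatim Transformers source) of what the upstream Trainer.get_batch_samples does: it pulls the next num_batches batches and counts the label tokens, so num_items_in_batch changes from one accumulation window to the next unless every example is packed to a fixed length.

    # Simplified paraphrase of transformers.Trainer.get_batch_samples (4.48.x); details may differ.
    def get_batch_samples(self, epoch_iterator, num_batches):
        batch_samples = []
        num_items_in_batch = None
        for _ in range(num_batches):
            try:
                batch_samples.append(next(epoch_iterator))
            except StopIteration:
                break
        # Count non-padding label tokens across the accumulation window; this value
        # varies per window unless examples are packed to a fixed length
        # (the --dataset_concatenation case mentioned above).
        if batch_samples and "labels" in batch_samples[0]:
            try:
                num_items_in_batch = sum((batch["labels"] != -100).sum() for batch in batch_samples)
            except (TypeError, AttributeError):
                pass
        return batch_samples, num_items_in_batch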

@regisss (Collaborator) left a comment:

I left a few comments. Can you also add in the description of the PR that it brings some changes from Transformers 4.49.0 too?

        for _ in range(total_updates):
            update_step += 1
            num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
-           batch_samples, num_items_in_batch = self.get_batch_samples_transformers(epoch_iterator, num_batches)
+           batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
regisss (Collaborator):

This will probably lead to an error in TRL examples as it is pinned to a quite old version: https://github.com/huggingface/optimum-habana/blob/main/examples/trl/requirements.txt#L1
I'll take a look to see if we can update it.

regisss (Collaborator):

We would need TRL v0.15 to avoid a clash between the two self.get_batch_samples methods of the trainers, but it requires Accelerate >= 0.34.
My suggestion is to keep get_batch_samples_transformers and add a TODO saying that it should be removed once the Accelerate dependency is upgraded (which should happen soon, see my comments above).

@@ -1351,7 +1365,7 @@ def _maybe_log_save_evaluate(self, tr_loss, _grad_norm, model, trial, epoch, ign
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs, start_time=start_time)
regisss (Collaborator):

This also probably collides with former versions of TRL

yafshar (Contributor, Author):

I did not look at the TRL examples; I will check further.

regisss (Collaborator):

Same as #1794 (comment)

@@ -2615,31 +2636,3 @@ def _zero_model_grad(self, model):
            except TypeError:
                model.zero_grad()
                model._zero_grad_kwargs = {}

-    def get_batch_samples_transformers(self, epoch_iterator, num_batches):
regisss (Collaborator):

Have you checked that there is no throughput regression? That is the reason I added this method.

yafshar (Contributor, Author):

In some cases there is a minor regression, but we can optimize later. One simple optimization is to pre-compute num_items_in_batch before the loop, as you suggested in your previous comment. I can do some profiling and measure the impact. Can you point me to the examples where you see a noticeable regression?
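One possible shape for that optimization, as a hedged sketch (self._items_per_batch is a made-up attribute that would be set once, e.g. in train(), when --dataset_concatenation guarantees a constant number of label tokens per batch; this is not the actual optimum-habana code):

    def get_batch_samples(self, epoch_iterator, num_batches):
        batch_samples = []
        for _ in range(num_batches):
            try:
                batch_samples.append(next(epoch_iterator))
            except StopIteration:
                break
        if getattr(self, "_items_per_batch", None) is not None:
            # Constant count: skip the per-window `labels != -100` reduction on device.
            return batch_samples, self._items_per_batch * len(batch_samples)
        # Fall back to the generic counting path.
        num_items_in_batch = None
        if batch_samples and "labels" in batch_samples[0]:
            num_items_in_batch = sum((b["labels"] != -100).sum() for b in batch_samples)
        return batch_samples, num_items_in_batch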

regisss (Collaborator):

You can try:

RUN_SLOW=1 GAUDI2_CI=1 pytest tests/test_examples.py -v -s -k "test_run_clm_gpt2_single_card"

Note that you'll have to uncomment this line: https://github.com/yafshar/optimum-habana/blob/transformers_future/tests/utils.py#L38
Using GPT2 because the test is much faster than other tests with bigger models.

yafshar (Contributor, Author):

The regression for this model and test is more than a 50% difference. I am looking into it. Thanks for the hint.

@yafshar yafshar force-pushed the transformers_future branch from 09298b6 to 7f32580 Compare February 24, 2025 18:56
Comment on lines -998 to +1011
-           batch_samples, num_items_in_batch = self.get_batch_samples_transformers(epoch_iterator, num_batches)
+           batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
yafshar (Contributor, Author):

I haven't been able to reproduce the issue. However, I noticed that others have reported similar problems: huggingface/trl#2275

What about changing it to the following?

    batch_samples, num_items_in_batch = super(Trainer, self).get_batch_samples(epoch_iterator, num_batches)

regisss (Collaborator):

That would be a better way of managing it for sure! I guess it will ultimately depend on whether or not we need to override get_batch_samples to avoid throughput regressions (linked to the comment above).
