[Transformers future] Loss Computation for Compatibility with Transformers 4.48.3 #1794
base: transformers_future
Conversation
… gradient accumulation issue
- Replaced _ContextParallelLoss with ContextParallelLossFunction for better clarity and consistency.
- Updated fixed_cross_entropy to use ContextParallelLossFunction for gathering losses across context parallel groups.
- Introduced ForCausalLMContextParallelLoss to handle loss computation for causal language modeling with context parallelism.

- Simplified loss computation by using a configurable loss function.
- Replaced the inline loss computation logic with a call to self.loss_function.
- Introduced ForCausalLMContextParallelLoss for context parallel loss computation.
- Updated the __init__ method to set the loss function based on the parallel strategy.
- Ensured compatibility with Transformers 4.48.3 by aligning with its structure and conventions.
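For illustration only, a minimal sketch of what the pieces named in these commit messages could look like. The names ContextParallelLossFunction and ForCausalLMContextParallelLoss come from the messages above, but every body below is an assumption, not the PR's actual implementation.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


class ContextParallelLossFunction(torch.autograd.Function):
    """Sums per-rank losses across a context-parallel group (illustrative assumption)."""

    @staticmethod
    def forward(ctx, loss, group):
        world_size = dist.get_world_size(group)
        gathered = [torch.zeros_like(loss) for _ in range(world_size)]
        dist.all_gather(gathered, loss.contiguous(), group=group)
        return torch.stack(gathered).sum()

    @staticmethod
    def backward(ctx, grad_output):
        # Each rank only backpropagates through its own local shard of the loss.
        return grad_output, None


def ForCausalLMContextParallelLoss(logits, labels, vocab_size, num_items_in_batch=None, group=None):
    # Standard causal-LM shift: tokens < n predict token n.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction=reduction)
    # Gather and sum the partial losses computed on each context-parallel shard.
    loss = ContextParallelLossFunction.apply(loss, group)
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss
```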
@regisss I removed the …
I left a few comments. Can you also add in the description of the PR that it brings some changes from Transformers 4.49.0 too?
            for _ in range(total_updates):
                update_step += 1
                num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
-               batch_samples, num_items_in_batch = self.get_batch_samples_transformers(epoch_iterator, num_batches)
+               batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
This will probably lead to an error in the TRL examples, as TRL is pinned to quite an old version there: https://github.com/huggingface/optimum-habana/blob/main/examples/trl/requirements.txt#L1
I'll take a look to see if we can update it.
We would need TRL v0.15 to not have a clash between the two self.get_batch_samples methods of the trainers, but it requires Accelerate >= 0.34.
My suggestion is to keep get_batch_samples_transformers and we can add a TODO saying that this should be removed when the Accelerate dependency is upgraded (which should happen soon, see my comments above).
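As an illustration of this suggestion, a sketch of what keeping the override with such a TODO might look like; the body loosely mirrors the batch-fetching and token-counting logic of Trainer.get_batch_samples and is an assumption, not the file's actual contents.

```python
def get_batch_samples_transformers(self, epoch_iterator, num_batches):
    # TODO: remove this override and rely on Trainer.get_batch_samples once the
    # Accelerate dependency is upgraded (>= 0.34) and TRL is bumped to v0.15.
    batch_samples, num_items_in_batch = [], None
    for _ in range(num_batches):
        try:
            batch_samples.append(next(epoch_iterator))
        except StopIteration:
            break
    if len(batch_samples) > 0 and "labels" in batch_samples[0]:
        # Count only the tokens that contribute to the loss (labels != -100).
        num_items_in_batch = sum(batch["labels"].ne(-100).sum().item() for batch in batch_samples)
    return batch_samples, num_items_in_batch
```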
@@ -1351,7 +1365,7 @@ def _maybe_log_save_evaluate(self, tr_loss, _grad_norm, model, trial, epoch, ign
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()

            self.log(logs, start_time=start_time)
This also probably collides with former versions of TRL
I did not look at the TRL examples; I will check further.
Same as #1794 (comment)
@@ -2615,31 +2636,3 @@ def _zero_model_grad(self, model):
            except TypeError:
                model.zero_grad()
                model._zero_grad_kwargs = {}

-   def get_batch_samples_transformers(self, epoch_iterator, num_batches):
Have you checked that there is no throughput regression? That was the reason I added it in the first place.
In some cases there is a minor regression, but we can optimize later. One simple optimization is to pre-compute num_items_in_batch before the loop, as you suggested in your previous comment. I can do some profiling and measure the impact. Can you point me to examples where you see a noticeable regression?
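To make the idea concrete, a rough sketch of the pre-computation mentioned above. The control flow and names (cached_num_items, the dataset_concatenation check) are assumptions for illustration, not the code that would land in the PR.

```python
# When every packed batch has the same number of loss tokens (e.g. with
# --dataset_concatenation), count them once and reuse the result instead of
# re-deriving num_items_in_batch at every update step.
constant_batch = getattr(args, "dataset_concatenation", False)  # assumed flag
cached_num_items = None

for _ in range(total_updates):
    update_step += 1
    num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
    if constant_batch and cached_num_items is not None:
        # Reuse the cached count; only fetch the micro-batches themselves.
        batch_samples = [next(epoch_iterator) for _ in range(num_batches)]
        num_items_in_batch = cached_num_items
    else:
        batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
        cached_num_items = num_items_in_batch
```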
You can try:
RUN_SLOW=1 GAUDI2_CI=1 pytest tests/test_examples.py -v -s -k "test_run_clm_gpt2_single_card"
Note that you'll have to uncomment this line: https://github.com/yafshar/optimum-habana/blob/transformers_future/tests/utils.py#L38
Using GPT2 because the test is much faster than other tests with bigger models.
The regression for this model and test is > 50%. I am looking into it. Thanks for the hint!
09298b6 to 7f32580
-               batch_samples, num_items_in_batch = self.get_batch_samples_transformers(epoch_iterator, num_batches)
+               batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
I haven't been able to reproduce the issue. However, I noticed that others have reported similar problems: huggingface/trl#2275
What about changing that to:
batch_samples, num_items_in_batch = super(Trainer, self).get_batch_samples(epoch_iterator, num_batches)
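For readers unfamiliar with the two-argument super form, a toy illustration of the mechanism behind this suggestion; the class names are placeholders and do not reflect the actual trainer hierarchy.

```python
# Explicit super(SomeClass, self): method lookup starts *after* SomeClass in the
# instance's MRO, bypassing overrides defined earlier in the chain.
class Base:
    def get_batch_samples(self):
        return "Base.get_batch_samples"


class Middle(Base):
    def get_batch_samples(self):
        return "Middle.get_batch_samples (e.g. an older override)"


class Child(Middle):
    def fetch(self):
        # Resolves directly to Base.get_batch_samples, skipping Middle's override.
        return super(Middle, self).get_batch_samples()


print(Child().fetch())  # prints "Base.get_batch_samples"
```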
That would be a better way of managing it for sure! I guess it will ultimately depend on whether or not we need to override get_batch_samples to avoid throughput regressions (linked to the comment above).
What does this PR do?
- Update the GaudiTrainer to better match Transformers 4.48.3, with minor enhancements from 4.49.0
- num_items_in_batch is only constant for the special case of running the examples with --dataset_concatenation
- Update the Llama model loss computation
- Refactor the context-parallel loss computation to better match upstream Transformers
Before submitting