Commit

Refactor loss computation
- Simplified loss computation by using a configurable loss function.
- Replaced the inline loss computation logic with a call to self.loss_function.
- Introduced ForCausalLMContextParallelLoss for context parallel loss computation.
- Updated the __init__ method to set the loss function based on the parallel strategy.
- Ensured compatibility with Transformers 4.48.3 by aligning with its structure and conventions.
yafshar committed Feb 22, 2025
1 parent 7eecd23 commit 7f32580
Showing 1 changed file with 5 additions and 27 deletions.
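
For reference, below is a minimal sketch of what a context-parallel causal-LM loss in the spirit of ForCausalLMContextParallelLoss could look like, reconstructed from the inline logic this commit removes from forward. The function names, gather helper, and process-group handling here are illustrative assumptions, not optimum-habana's actual implementation; in particular, the real gather helper must use the sequence/context-parallel process group and be autograd-aware so gradients flow back through each rank's shard.

    # Sketch only (assumption): not optimum-habana's actual ForCausalLMContextParallelLoss.
    import torch
    import torch.distributed as dist


    def gather_loss_from_context_parallel(per_token_loss, group=None):
        # Stand-in for the gather helper: each rank holds the per-token loss for its
        # slice of the sequence, so collect all slices before reducing. Illustrative
        # only; a real implementation must be autograd-aware and use the
        # sequence/context-parallel group.
        if not (dist.is_available() and dist.is_initialized()):
            return per_token_loss
        gathered = [torch.empty_like(per_token_loss) for _ in range(dist.get_world_size(group))]
        dist.all_gather(gathered, per_token_loss, group=group)
        return torch.cat(gathered)


    def causal_lm_context_parallel_loss(logits, labels, vocab_size, num_items_in_batch=None, **kwargs):
        # Upcast to float to avoid precision issues, then shift so that tokens < n predict n.
        logits = logits.float()
        shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
        shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)

        # Per-token loss on this rank's portion of the sequence (no reduction yet).
        per_token_loss = torch.nn.CrossEntropyLoss(reduction="none")(shift_logits, shift_labels)
        loss_all = gather_loss_from_context_parallel(per_token_loss)

        # Reduce the same way the removed inline code did.
        if num_items_in_batch is None:
            return loss_all.mean()
        return loss_all.sum() / num_items_in_batch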
32 changes: 5 additions & 27 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -1418,7 +1418,10 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM):
     def __init__(self, config, parallel_strategy: DistributedStrategy = NoOpStrategy):
         config.parallel_strategy = parallel_strategy
         super().__init__(config)
-        self._is_context_parallel_active: bool = parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1
+        if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
+            from ....distributed.contextparallel import ForCausalLMContextParallelLoss
+
+            self._loss_function = ForCausalLMContextParallelLoss
 
     def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
         self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
@@ -1507,32 +1510,7 @@ def forward(
 
         loss = None
         if labels is not None:
-            # Collect losses from context parallel group
-            # Each rank in group calculates loss on partial outputs
-            if self._is_context_parallel_active:
-                from ....distributed.contextparallel import _get_loss_from_context_parallel
-
-                # Upcast to float if we need to compute the loss to avoid potential precision issues
-                logits = logits.float()
-                # Shift so that tokens < n predict n
-                shift_logits = logits[..., :-1, :].contiguous()
-                shift_labels = labels[..., 1:].contiguous()
-                # Flatten the tokens
-                loss_fct = torch.nn.CrossEntropyLoss()
-                shift_logits = shift_logits.view(-1, self.config.vocab_size)
-                shift_labels = shift_labels.view(-1)
-                # Enable model parallelism
-                shift_labels = shift_labels.to(shift_logits.device)
-
-                loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
-                loss_all = _get_loss_from_context_parallel(loss_fct(shift_logits, shift_labels))
-                num_items_in_batch = kwargs.get("num_items_in_batch", None)
-                if num_items_in_batch is None:
-                    loss = torch.mean(loss_all)
-                else:
-                    loss = torch.sum(loss_all) / num_items_in_batch
-            else:
-                loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
