[Transformers future] Loss Computation for Compatibility with Transformers 4.48.3 #1794

Open · wants to merge 4 commits into base: transformers_future
52 changes: 46 additions & 6 deletions optimum/habana/distributed/contextparallel.py
@@ -7,24 +7,64 @@
)


# Gather losses across context parallel group
class _ContextParallelLoss(torch.autograd.Function):
class ContextParallelLossFunction(torch.autograd.Function):
"""
Gather losses across context parallel group.
This custom autograd function is designed to handle the distribution of loss computation
across multiple parallel contexts in a distributed training setup. It ensures that the loss
is gathered from all devices involved in the parallel context, allowing for consistent and
accurate computation of the overall loss.
The forward method gathers the loss from all ranks in the context parallel group, while the
backward method ensures that gradients are correctly synchronized across the different parallel
contexts.
"""

@staticmethod
def forward(ctx, loss):
ctx.seqlen = loss.size(0) * get_sequence_parallel_world_size()

# Create a tensor to gather all losses from context parallel group
loss_all = torch.empty(ctx.seqlen, dtype=loss.dtype, device=loss.device)
# Gather losses from all ranks in the group
torch.distributed.all_gather_into_tensor(loss_all, loss, group=get_sequence_parallel_group())
return loss_all

@staticmethod
def backward(ctx, grad_output):
step_seqlen = ctx.seqlen // get_sequence_parallel_world_size()
sp_rank = get_sequence_parallel_rank()
# Extract the relevant part of the gradient for this rank
grad_output_part = grad_output[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1)]

return grad_output_part, None
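
For intuition, here is a minimal single-process sketch (not part of this diff) that mimics the end-to-end effect of the gather-then-slice pair, assuming a hypothetical context-parallel world size of 2 and rank 0; the real implementation performs the gather with all_gather_into_tensor over the sequence parallel group.

import torch

world_size = 2                                   # assumed context-parallel world size
rank = 0                                         # assumed rank of this device
local_loss = torch.rand(4, requires_grad=True)   # per-token losses computed on this rank
remote_loss = torch.rand(4)                      # stand-in for the other rank's (detached) losses

# Forward: gather losses from all ranks in rank order (the real code uses all_gather_into_tensor).
loss_all = torch.cat([local_loss, remote_loss]) if rank == 0 else torch.cat([remote_loss, local_loss])

# Reduce as fixed_cross_entropy does when num_items_in_batch is None.
loss = loss_all.mean()
loss.backward()

# Backward: each rank keeps only the gradient slice for its own tokens,
# i.e. 1 / (world_size * local_seqlen) = 1/8 per local token in this sketch.
print(local_loss.grad)  # tensor([0.1250, 0.1250, 0.1250, 0.1250])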


def _get_loss_from_context_parallel(vocab_parallel_loss):
return _ContextParallelLoss.apply(vocab_parallel_loss)
def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
loss_all = torch.nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
# Apply context parallel loss
loss_all = ContextParallelLossFunction.apply(loss_all)
if num_items_in_batch is None:
loss = torch.mean(loss_all)
else:
loss = torch.sum(loss_all) / num_items_in_batch
return loss
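
The two reduction branches differ only in the denominator: with num_items_in_batch unset, the gathered per-token losses are averaged over the full gathered sequence, while a caller-supplied token count turns the reduction into sum / num_items_in_batch, which keeps the loss scale consistent under gradient accumulation. A plain-tensor illustration (not part of this diff):

import torch

loss_all = torch.tensor([0.5, 1.0, 1.5, 2.0])  # stand-in for the gathered per-token losses
print(loss_all.mean())                          # tensor(1.2500) -> num_items_in_batch is None
print(loss_all.sum() / 8)                       # tensor(0.6250) -> num_items_in_batch = 8 (e.g. tokens across accumulation steps)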


def ForCausalLMContextParallelLoss(
logits, labels, vocab_size: int, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs
):
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# Flatten the tokens
shift_logits = shift_logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)

loss = fixed_cross_entropy(shift_logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
return loss
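
As a shape reference, the sketch below (single device, no context parallelism, not part of this diff) reproduces the shift-and-flatten that ForCausalLMContextParallelLoss performs; with a world size of 1 the gathered mean reduces to plain shifted cross-entropy, so this is the value each rank's computation stays consistent with. Sequence-parallel group initialization is assumed to happen elsewhere and is skipped here.

import torch
import torch.nn.functional as F

batch, seq_len, vocab_size = 2, 8, 32
logits = torch.randn(batch, seq_len, vocab_size)          # model outputs
labels = torch.randint(0, vocab_size, (batch, seq_len))   # next-token targets

# Same shift so that tokens < n predict n, then flatten, as in the function above.
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
shift_labels = labels[..., 1:].contiguous().view(-1)

reference_loss = F.cross_entropy(shift_logits.float(), shift_labels)  # what the gathered mean equals at world size 1
print(reference_loss)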
29 changes: 5 additions & 24 deletions optimum/habana/transformers/models/llama/modeling_llama.py
@@ -1418,6 +1418,10 @@ class GaudiLlamaForCausalLM(LlamaForCausalLM):
def __init__(self, config, parallel_strategy: DistributedStrategy = NoOpStrategy):
config.parallel_strategy = parallel_strategy
super().__init__(config)
if parallel_state.sequence_parallel_is_initialized() and parallel_state.get_sequence_parallel_world_size() > 1:
from ....distributed.contextparallel import ForCausalLMContextParallelLoss

self._loss_function = ForCausalLMContextParallelLoss

def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)
@@ -1506,30 +1510,7 @@ def forward(

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = torch.nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
# Collect losses from context parallel group
# Each rank in group calculates loss on partial outputs
if (
parallel_state.sequence_parallel_is_initialized()
and parallel_state.get_sequence_parallel_world_size() > 1
):
from ....distributed.contextparallel import _get_loss_from_context_parallel

loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
loss_all = _get_loss_from_context_parallel(loss_fct(shift_logits, shift_labels))
loss = torch.mean(loss_all)
else:
loss = loss_fct(shift_logits, shift_labels)
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

if not return_dict:
output = (logits,) + outputs[1:]
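
The change above works because forward now delegates to self.loss_function, which this PR points at ForCausalLMContextParallelLoss in __init__ when a sequence parallel group is active. A simplified, hypothetical mock of that dispatch pattern (not the actual transformers implementation; names are illustrative only) looks roughly like this:

import torch

def default_causal_lm_loss(logits, labels, vocab_size, **kwargs):
    # Plain shifted cross-entropy, standing in for the default loss.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    return torch.nn.functional.cross_entropy(shift_logits.float(), shift_labels)

class TinyCausalLM:
    """Hypothetical stand-in for the model wrapper."""

    def __init__(self, loss_fn=None):
        # Mirrors the diff: __init__ installs a custom callable when context parallelism is active.
        self._loss_function = loss_fn or default_causal_lm_loss

    @property
    def loss_function(self):
        return self._loss_function

    def forward(self, logits, labels, vocab_size):
        # Mirrors the new forward: the loss is computed by whatever callable was installed.
        return self.loss_function(logits=logits, labels=labels, vocab_size=vocab_size)

model = TinyCausalLM()  # under context parallelism: TinyCausalLM(ForCausalLMContextParallelLoss)
logits, labels = torch.randn(2, 8, 32), torch.randint(0, 32, (2, 8))
print(model.forward(logits, labels, vocab_size=32))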