Skip to content

Commit

Permalink
just remove arg
Browse files Browse the repository at this point in the history
  • Loading branch information
milocress committed Jan 30, 2025
1 parent 53625b2 commit cd91880
Showing 1 changed file with 0 additions and 4 deletions.
4 changes: 0 additions & 4 deletions llmfoundry/models/mpt/modeling_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,6 @@ def compute_loss_from_logits(
shift_labels: bool,
labels: torch.Tensor,
loss_fn: nn.Module,
sample_weighing_factor: Optional[torch.Tensor] = None,
) -> torch.Tensor:
targets = get_targets(labels) if shift_labels else labels

Expand All @@ -1359,8 +1358,6 @@ def compute_loss_from_logits(
loss = losses.sum()
else:
loss = losses.sum() / (targets != loss_fn.ignore_index).sum()
if sample_weighing_factor is not None:
raise ValueError('sample_weighing_factor has been discontinued!')

return loss

Expand Down Expand Up @@ -1469,7 +1466,6 @@ def loss(self, outputs: CausalLMOutputWithPast,
self.shift_labels,
batch['labels'],
self.loss_fn,
batch.get('sample_weighing_factor', None),
)

if self.config.ffn_config['ffn_type'] in ffns_with_megablocks:
Expand Down

0 comments on commit cd91880

Please sign in to comment.