I encountered unexpected behavior while using the PaliGemma model. Specifically, the pad tokens' positions in the causal mask end up set to zero (i.e., unmasked) during the execution of self._update_causal_mask(attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training).
After investigating, I found that this behavior stems from a code change between the latest version (4.48.0) and a previous version (e.g., 4.43.0): the order in which the pad tokens are masked has been reversed, leading to inconsistent behavior, especially when using left padding.
Code Comparison
Latest Version (4.48.0):
if attention_mask is not None:
    causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
    mask_length = attention_mask.shape[-1]
    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
    padding_mask = padding_mask == 0
    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
        padding_mask, min_dtype
    )
    # we are training thus we need to create a full mask on the image + prefix but causal on suffix
    if is_training:
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
        )
Previous Version (4.43.0):
if token_type_ids is not None and labels is not None:
    # we are training thus we need to create a full mask on the image + prefix but causal on suffix
    target_length = cache_position[-1] + 1
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
    )
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
            causal_mask.device
        )
        # unmask the prefill
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
        )
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )
Issue Details
In the latest version (4.48.0), the pad tokens are first masked using the padding_mask:
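For reference, the relevant excerpt from the 4.48.0 snippet above:

padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    padding_mask, min_dtype
)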
However, when the following block of code is executed during training, the pad tokens (which have token_type_ids == 0) are unmasked again, effectively reverting the previous masking operation:
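(This is the is_training branch from the 4.48.0 snippet above.)

if is_training:
    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
        token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
    )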
Impact
This behavior is particularly problematic when using left padding, as the pad tokens end up being unmasked in the latest version, whereas they were correctly masked in the previous version (4.43.0).
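To make the effect concrete, here is a minimal, self-contained sketch; it is not the actual _update_causal_mask implementation, and the toy shapes, values, and the build_mask helper are made up for illustration. It builds a 4x4 causal mask for a left-padded sequence under both orderings:

# Minimal sketch, not the actual _update_causal_mask implementation: toy tensors
# illustrating why the order of the two masked_fill calls matters for left padding.
import torch

min_dtype = torch.finfo(torch.float32).min
seq_len = 4
attention_mask = torch.tensor([[0, 1, 1, 1]])  # left padding: position 0 is a pad token
token_type_ids = torch.tensor([[0, 0, 0, 1]])  # 0 = image/prefix tokens (pads also carry 0)

def build_mask(mask_padding_last: bool) -> torch.Tensor:
    # plain causal mask: min_dtype above the diagonal, 0 elsewhere
    causal = torch.triu(torch.full((seq_len, seq_len), min_dtype), diagonal=1)[None, None]
    causal = causal.clone()
    padding_mask = (causal + attention_mask[:, None, None, :]) == 0
    unmask_prefix = token_type_ids[:, None, None, :] == 0
    if mask_padding_last:
        # 4.43.0 order: unmask image + prefix first, mask the padding last
        causal = causal.masked_fill(unmask_prefix, 0)
        causal = causal.masked_fill(padding_mask, min_dtype)
    else:
        # 4.48.0 order (training): mask the padding first, then unmask image + prefix,
        # which re-unmasks the pad column because pads also have token_type_ids == 0
        causal = causal.masked_fill(padding_mask, min_dtype)
        causal = causal.masked_fill(unmask_prefix, 0)
    return causal

print((build_mask(True)[0, 0, :, 0] == min_dtype).all())   # tensor(True): pad column stays masked
print((build_mask(False)[0, 0, :, 0] == 0).all())          # tensor(True): pad column ends up unmasked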
Question
Is this an intentional change, or is it a bug? If it’s a bug, could you please provide guidance on how to address it? I’d be happy to help with a fix if needed.
Thanks for your support!
Information
The official example scripts
My own modified scripts
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Yep, seems like a bug and was definitely not intended. Indeed, the padding should be masked even in training mode. What we can do is switch the order of padding masking and suffix unmasking in update_causal_mask. Would you like to open a PR?
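For concreteness, a rough sketch of what that reordering could look like; this is just the 4.48.0 snippet from above with the padding masking moved after the image/prefix unmasking, not an actual patch:

if attention_mask is not None:
    causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
    mask_length = attention_mask.shape[-1]
    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
    # we are training thus we need to create a full mask on the image + prefix but causal on suffix
    if is_training:
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
        )
    # mask the padding last so the pad positions cannot be re-unmasked by the step above
    padding_mask = padding_mask == 0
    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
        padding_mask, min_dtype
    )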
Hey @zucchini-nlp, @amyeroberts, @molbap, I have raised PR #35859 as a fix for this. I hope you can review it and merge it once the issue is confirmed resolved. I'll address any comments you have on it as well, if required.
System Info
transformers version: 4.48.0

Who can help?
@amyeroberts @molbap
Reproduction
Expected behavior
The pad tokens should remain masked, as in 4.43.0; instead, they end up being unmasked.