Fix PaliGemma Pad Token Masking During Training #35855 #35859
base: main
Conversation
Thanks for opening a PR! Can you move the test to test_modeling_paligemma and adapt it for dummy weights? I don't think we need a slow test with a big model, since we're just checking the mask on the attention weights.
@@ -0,0 +1,62 @@
import unittest
We need to add a test in the test_modeling_paligemma.py file with tiny dummy model weights, and a single test covering attention with/without suffix should be enough.
Okay, I have done this in commit 96e1b43 and removed the separate test file from the PR. Kindly review it for any further changes. I've also pushed the style fixes, so if everything looks fine we should be good to merge.
Thanks for iterating on this!
LGTM, but for the test we'd better just add a new one under PaliGemmaForConditionalGenerationModelTest, which already creates a dummy model and inputs. Sorry if that wasn't clear the first time.
Also, let's test the model forward in general with output_attentions=True, instead of relying on a single _update_causal_mask call. And add a test for when token type ids are passed vs. when they are not, to make sure attention is masked correctly in all cases.
Problem Statement
In the PaliGemma model's _update_causal_mask function, padding tokens were being incorrectly unmasked during training. This occurred because the order of operations first applied the padding mask and then unmasked the prefix tokens (including pad tokens) during training, leading to inconsistent behavior, especially with left padding.

Fixes #35855
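To make the ordering issue concrete, here is a minimal toy reproduction in PyTorch. The tensor values, shapes, and the all-zeros starting mask are illustrative stand-ins, not the actual _update_causal_mask implementation:

```python
import torch

min_dtype = torch.finfo(torch.float32).min
batch, seq_len = 1, 6

# Left-padded toy batch: 2 pad tokens, 3 prefix tokens (image + prompt), 1 suffix token.
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1]])   # 0 = pad
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 1]])   # 0 = prefix (pads included), 1 = suffix

# Additive attention mask: 0 means "attend", min_dtype means "masked".
causal_mask = torch.zeros(batch, 1, seq_len, seq_len)

# Old order, step 1: mask out pad key positions.
pad_keys = (attention_mask == 0)[:, None, None, :]
causal_mask = causal_mask.masked_fill(pad_keys, min_dtype)

# Old order, step 2 (training only): unmask prefix keys for bidirectional attention.
# Pad tokens also carry token_type_ids == 0, so this re-enables attention to them.
prefix_keys = (token_type_ids == 0)[:, None, None, :]
causal_mask = causal_mask.masked_fill(prefix_keys, 0)

print(causal_mask[0, 0, -1])  # the two pad columns are back to 0, i.e. attendable again
```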
Approach
After analyzing the masking logic, we identified that the issue stemmed from the sequence of operations in the mask application. The solution was to reorder the masking operations so that:

1. the prefix (image + text prompt) positions are unmasked for bidirectional attention during training first, and
2. the padding mask is applied last.

This ensures pad tokens remain masked regardless of their position or training state.
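Continuing the toy example from the Problem Statement above, a sketch of the reordered operations (again illustrative, not the exact patch):

```python
# Same toy tensors as above, with the two operations swapped.
causal_mask = torch.zeros(batch, 1, seq_len, seq_len)

# New order, step 1 (training only): unmask the prefix first.
causal_mask = causal_mask.masked_fill(prefix_keys, 0)

# New order, step 2: apply the padding mask last, so pad keys end up masked
# no matter what token_type_ids they carry.
causal_mask = causal_mask.masked_fill(pad_keys, min_dtype)

print(causal_mask[0, 0, -1])  # the two pad columns stay at min_dtype
```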
Implementation
The fix involved reordering the masking operations in the _update_causal_mask function while keeping the mathematical operations themselves the same. This approach ensures that pad tokens stay masked in both training and inference, and that the rest of the mask construction is unchanged.

Test Coverage
The new test validates:

- the full model forward with output_attentions=True, rather than a single _update_causal_mask call;
- attention masking both with and without a suffix (labels);
- behavior when token_type_ids are passed vs. when they are not;
- that pad token positions are never attended to, including with left padding.
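A rough sketch of the kind of check involved, along the lines suggested in the review. Here `model` stands in for the tiny dummy PaliGemmaForConditionalGeneration built by PaliGemmaForConditionalGenerationModelTest, and input_ids / pixel_values / attention_mask / token_type_ids / labels are a prepared left-padded batch; none of this is the verbatim test added in the PR:

```python
import torch

model.train()
outputs = model(
    input_ids=input_ids,            # left-padded prompt (+ optional suffix)
    pixel_values=pixel_values,
    attention_mask=attention_mask,  # 0 at pad positions
    token_type_ids=token_type_ids,  # omit in a second run to cover the "not passed" case
    labels=labels,                  # omit in a second run to cover the "no suffix" case
    output_attentions=True,
)

# No query should put attention weight on a pad key position, in any layer.
pad_keys = (attention_mask == 0)[:, None, None, :]   # (batch, 1, 1, seq_len)
for attn in outputs.attentions:                      # each: (batch, heads, seq_len, seq_len)
    assert torch.all(attn.masked_select(pad_keys.expand_as(attn)) == 0)
```

With an additive mask filled with the dtype minimum, the softmax weight on masked key positions underflows to zero under eager attention, which is what the assertion relies on.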
cc: @amyeroberts @molbap @zucchini-nlp. Kindly review this whenever you find the time.