Fix PaliGemma Pad Token Masking During Training #35855 #35859

Open · wants to merge 11 commits into main

Conversation

sambhavnoobcoder

Problem Statement

In the PaliGemma model's _update_causal_mask function, padding tokens were incorrectly unmasked during training. The padding mask was applied first, and the prefix tokens (which can include pad tokens) were then unmasked during training, re-exposing pad positions and leading to inconsistent behavior, especially with left padding.

Fixes: #35855

Approach

After analyzing the masking logic, we identified that the issue stemmed from the sequence of operations in mask application. The solution was to reorder the masking operations to:

  1. First unmask prefix tokens during training mode
  2. Then apply padding masks

This ensures pad tokens remain masked regardless of their position or training state.
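
A minimal sketch of the reordered mask construction (simplified and with illustrative names; not the exact upstream implementation):

    import torch

    def build_causal_mask(attention_mask, token_type_ids, dtype, is_training):
        """attention_mask: (batch, seq_len), 0 at pad positions.
        token_type_ids: (batch, seq_len), 0 for prefix (image + prompt) tokens."""
        min_dtype = torch.finfo(dtype).min
        batch, seq_len = attention_mask.shape

        # Standard causal mask: future positions are filled with min_dtype.
        causal = torch.triu(
            torch.full((seq_len, seq_len), min_dtype, dtype=dtype), diagonal=1
        )
        causal = causal[None, None, :, :].expand(batch, 1, seq_len, seq_len).clone()

        # 1) First unmask the prefix during training (bidirectional attention
        #    over image + prompt tokens).
        if is_training:
            causal = causal.masked_fill(token_type_ids[:, None, None, :] == 0, 0)

        # 2) Then apply the padding mask, so pad columns end up at min_dtype
        #    even if step 1 touched them. (The bug was performing these two
        #    steps in the opposite order.)
        causal = causal.masked_fill(attention_mask[:, None, None, :] == 0, min_dtype)
        return causal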

Implementation

The fix reorders the masking operations in the _update_causal_mask function while keeping the individual operations themselves unchanged. This approach ensures:

  • No change to the individual masking operations, only to their order
  • Minimal code changes
  • Preserved backward compatibility
  • Consistent behavior across training and inference modes

Test Coverage

The new test validates:

  1. Pad tokens remain masked (value = dtype.min) during training
  2. Non-pad tokens are properly masked/unmasked based on their position
  3. Behavior is consistent across different batch sizes and sequence lengths
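
A rough sketch of that check, reusing the illustrative build_causal_mask helper from the sketch above (the batch layout here is an assumption for illustration, not the PR's exact test code):

    import torch

    def test_pad_tokens_stay_masked():
        dtype = torch.float32
        min_dtype = torch.finfo(dtype).min
        # batch of 2, seq_len of 5, left padding on the first row
        attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                                       [1, 1, 1, 1, 1]])
        token_type_ids = torch.tensor([[0, 0, 0, 0, 1],
                                       [0, 0, 0, 1, 1]])
        mask = build_causal_mask(attention_mask, token_type_ids, dtype, is_training=True)

        # every pad column must stay fully masked for every query position
        pad_cols = (attention_mask == 0)[:, None, None, :].expand_as(mask)
        assert torch.all(mask.masked_select(pad_cols) == min_dtype)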

Screenshots

[Screenshot: 2025-01-24 at 12:57 AM]

cc: @amyeroberts @molbap @zucchini-nlp, kindly review this whenever you find the time.

Member

@zucchini-nlp left a comment


Thanks for opening a PR! Can you move the test to test_modeling_paligemma and adapt it for dummy weights? I don't think we need a slow test with a big model, since we're just checking the mask on the attention weights.

@@ -0,0 +1,62 @@
import unittest
Member


We need to add a test in the test_modeling_paligemma.py file with tiny dummy model weights; a single test covering attention with/without suffix should be enough.

Author

@sambhavnoobcoder Jan 24, 2025


Okay, I have done this in commit 96e1b43 and removed the separate testing file from the PR. Kindly review it for any further changes. I've also pushed the style fixes, so if everything looks fine we are good to merge.

Member

@zucchini-nlp left a comment


Thanks for iterating on this!

LGTM, but for the test we'd better just add a new one under PaliGemmaForConditionalGenerationModelTest, which already creates a dummy model and inputs. Sorry if it wasn't clear the first time.
Also, let's test the model forward in general with output_attentions=True instead of relying on a single update_causal_mask call, and add a test for when token type ids are passed vs. not passed, to make sure attention is masked correctly in all cases.
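
For reference, a rough sketch of the kind of forward-level check described here (the helper signature, input names, and shapes are assumptions for illustration, not the test actually added in the PR):

    import torch

    def check_pad_attention(model, inputs):
        """Full forward pass; attention paid to pad positions should be ~0."""
        model.eval()
        with torch.no_grad():
            outputs = model(**inputs, output_attentions=True)
        pad_cols = inputs["attention_mask"] == 0     # (batch, seq_len)
        valid_rows = inputs["attention_mask"] == 1
        for layer_attn in outputs.attentions:        # (batch, heads, q_len, k_len)
            # only check real (non-pad) query rows attending to pad columns
            sel = valid_rows[:, None, :, None] & pad_cols[:, None, None, :]
            assert torch.all(layer_attn.masked_select(sel.expand_as(layer_attn)) < 1e-6)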
