
PaliGemma Pad Token not Masked #35855

Open
2 of 4 tasks
TangsengT opened this issue Jan 23, 2025 · 2 comments · May be fixed by #35859

Comments

@TangsengT commented Jan 23, 2025

System Info

  • transformers version: 4.48.0
  • Platform: Linux-6.1.92-3-x86_64-with-glibc2.31
  • Python version: 3.10.15
  • Huggingface_hub version: 0.26.2
  • Safetensors version: 0.4.5
  • Accelerate version: 1.1.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: parallel set-up
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A100 80GB PCIe

Who can help?

@amyeroberts @molbap
Description:

I encountered unexpected behavior while using the PaliGemma model. Specifically, the pad token positions in the causal mask end up set to zero (i.e., unmasked) after self._update_causal_mask(attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training) runs.

After investigating, I found that this behavior stems from a code change between the latest version (4.48.0) and an earlier version (4.43.0): the order in which the padding mask and the prefix unmasking are applied has been reversed, leading to inconsistent behavior, especially with left padding.

Code Comparison

Latest Version (4.48.0):

if attention_mask is not None:
    causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
    mask_length = attention_mask.shape[-1]
    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
    padding_mask = padding_mask == 0
    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
        padding_mask, min_dtype
    )
    # we are training thus we need to create a full mask on the image + prefix but causal on suffix
    if is_training:
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
        )

Previous Version (4.43.0):

if token_type_ids is not None and labels is not None:
    # we are training thus we need to create a full mask on the image + prefix but causal on suffix
    target_length = cache_position[-1] + 1
    causal_mask = torch.full(
        (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
    )
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(inputs_embeds.shape[0], 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        mask_length = attention_mask.shape[-1]
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
            causal_mask.device
        )
        # unmask the prefill
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
        )
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )

Issue Details
In the latest version (4.48.0), the pad tokens are first masked using the padding_mask:

padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
    padding_mask, min_dtype
)

However, when the following block of code is executed during training:

if is_training:
    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
        token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
    )

the pad tokens (which have token_type_ids == 0) are unmasked again, effectively reverting the previous masking operation.
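
To make the ordering concrete, here is a minimal, self-contained sketch with synthetic tensors (not the model's actual code): with left padding, the pad positions share token_type_id 0 with the prefix, so the second masked_fill writes 0 back over the pad columns.

import torch

min_dtype = torch.finfo(torch.float32).min
seq_len = 5

# Left padding: the first two positions are pad tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1]])
# Pad and prefix positions carry token_type_id 0; the suffix position carries 1.
token_type_ids = torch.tensor([[0, 0, 0, 0, 1]])

# Causal mask convention: 0 = attend, min_dtype = blocked.
causal_mask = torch.full((1, 1, seq_len, seq_len), min_dtype).triu(diagonal=1)

# Step 1 (4.48.0): mask the pad columns.
padding_mask = (causal_mask + attention_mask[:, None, None, :]) == 0
causal_mask = causal_mask.masked_fill(padding_mask, min_dtype)
print(causal_mask[0, 0, -1])  # pad columns are min_dtype -> masked, as expected

# Step 2 (the is_training branch): unmask image + prefix where token_type_ids == 0,
# which also matches the pad columns and sets them back to 0 -> unmasked again.
causal_mask = causal_mask.masked_fill(token_type_ids[:, None, None, :] == 0, 0)
print(causal_mask[0, 0, -1])  # pad columns are 0 again, which is the reported bug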

Impact
This behavior is particularly problematic when using left padding, as the pad tokens end up being unmasked in the latest version, whereas they were correctly masked in the previous version (4.43.0).

Question
Is this an intentional change, or is it a bug? If it’s a bug, could you please provide guidance on how to address it? I’d be happy to help with a fix if needed.

Thanks for your support!

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

from PIL import Image
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

# "path" and "image path" are placeholders for a local PaliGemma checkpoint and image.
processor = AutoProcessor.from_pretrained("path")
model = PaliGemmaForConditionalGeneration.from_pretrained("path")

# Two prompts of different lengths so the batch needs padding.
prompt = ["<image>caption en", "<image>caption en wwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwwww wwwwwwwwwwwwwwwwwww"]
labels = ["123123123123", "456456456456"]
image_file = ["image path"] * 2
raw_image = [Image.open(i).convert("RGB") for i in image_file]

# suffix=labels makes the processor build labels, so the forward pass takes the
# is_training branch of _update_causal_mask.
inputs = processor(text=prompt, images=raw_image, suffix=labels,
                   return_tensors="pt", padding="longest")
model(**inputs)
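
As a quick sanity check (assuming the tokenizer pads on the left, as described above), the zero entries of the returned attention_mask mark exactly the columns that should stay masked:

# The shorter prompt's row starts with zeros (left padding); those columns should
# remain masked by _update_causal_mask even when labels are provided (training).
print(inputs["attention_mask"])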

Expected behavior

The pad tokens should remain masked during training (as in 4.43.0); currently they end up being unmasked.

@zucchini-nlp (Member) commented:

@TangsengT hey!

Yep, this looks like a bug and was definitely not intended. The padding should indeed stay masked even in training mode. What we can do is switch the order of the padding masking and the suffix unmasking in _update_causal_mask. Would you like to open a PR?
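
For illustration, here is a sketch of that reordering applied to the 4.48.0 snippet quoted above (same variables; not necessarily what PR #35859 implements): do the prefix unmasking first, then apply the padding mask last so the pad columns keep min_dtype.

if attention_mask is not None:
    causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
    mask_length = attention_mask.shape[-1]
    # Training: unmask the image + prefix first ...
    if is_training:
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            token_type_ids[:, None, None, :].to(causal_mask.device) == 0, 0
        )
    # ... then mask the pad columns last, so they cannot be unmasked afterwards.
    padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
    padding_mask = padding_mask == 0
    causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
        padding_mask, min_dtype
    )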

@sambhavnoobcoder commented:

Hey @zucchini-nlp, @amyeroberts, @molbap, I have opened PR #35859 as a fix for this. I hope you can review it and merge it if it resolves the issue. I'll address any comments you have on it as well.

Thank you @TangsengT for raising the issue. 🤗
