
Fix sdpa in sam and refactor relative position embeddings #36422

Open · wants to merge 3 commits into main
Conversation

geetu040 (Contributor)

What does this PR do?

This PR identifies and fixes the following problems with SamVisionSdpaAttention:

  • The attn_weights returned by SamVisionAttention and SamVisionSdpaAttention have different shapes: [batch_size, num_attention_heads, seq_len, seq_len] and [batch_size * num_attention_heads, seq_len, seq_len] respectively.
  • torch's scaled_dot_product_attention does not return attn_weights, so SamVisionSdpaAttention computed them manually, which defeats the purpose of using SDPA in the first place. Instead, when output_attentions==True we should fall back to the eager implementation in SamVisionAttention.
  • The relative position embedding logic in both attention layers can be refactored into cleaner code, with fewer reshapes and a single function, SamVisionAttention.get_decomposed_rel_pos.
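
The first two points can be illustrated with a rough NumPy sketch (assumed shapes, not the Transformers implementation): attention weights can only be returned when they are materialized explicitly, which is exactly the intermediate that a fused SDPA kernel avoids, so requesting them forces an eager-style path.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v, output_attentions=False):
    # Hypothetical sketch: compute attention eagerly so that the
    # weights exist as a tensor. A fused SDPA kernel would skip this
    # materialization, which is why it cannot return attn_weights.
    scale = q.shape[-1] ** -0.5
    attn_weights = softmax((q @ np.swapaxes(k, -2, -1)) * scale)
    out = attn_weights @ v
    # Eager convention for the weights: [batch, heads, seq, seq]
    return (out, attn_weights) if output_attentions else (out, None)

q = k = v = np.random.rand(2, 4, 8, 16)  # [batch, heads, seq, head_dim]
out, w = attention(q, k, v, output_attentions=True)
```

Here `out` has shape (2, 4, 8, 16) and `w` has shape (2, 4, 8, 8), i.e. both implementations should agree on the 4-D [batch, heads, seq, seq] convention rather than flattening batch and heads together.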

I had to change the add_decomposed_rel_pos function: when output_attentions==True we fall back to "eager", in which case the overloaded method is not called and the parent's method is not compatible.
This change affected some other classes as well:

  1. GotOcr2VisionAttention inherits SamVisionAttention in full; since the model is modular, simply running utils/modular_model_converter.py should fix it.
  2. TFSamVisionAttention mirrors the SamVisionAttention class; the two now look slightly different.
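
For context, the decomposed relative position scheme that this refactor touches can be sketched as follows (a simplified NumPy illustration with hypothetical shapes, not the actual SAM code): the 2-D relative bias is built from separate height and width embedding tables and broadcast-added, instead of storing one (q_h*q_w) x (k_h*k_w) table.

```python
import numpy as np

def decomposed_rel_pos(q, rel_h, rel_w):
    # q:     [batch, q_h, q_w, dim]  queries laid out on the 2-D grid
    # rel_h: [q_h, k_h, dim]         per-axis relative embeddings (height)
    # rel_w: [q_w, k_w, dim]         per-axis relative embeddings (width)
    bias_h = np.einsum("bhwc,hkc->bhwk", q, rel_h)  # [b, q_h, q_w, k_h]
    bias_w = np.einsum("bhwc,wkc->bhwk", q, rel_w)  # [b, q_h, q_w, k_w]
    b, qh, qw, _ = q.shape
    kh, kw = rel_h.shape[1], rel_w.shape[1]
    # Broadcast-add the two axes, then flatten to attention-score shape
    bias = bias_h[..., :, None] + bias_w[..., None, :]  # [b, qh, qw, kh, kw]
    return bias.reshape(b, qh * qw, kh * kw)

q = np.random.rand(1, 4, 4, 8)
rel_h = np.random.rand(4, 4, 8)
rel_w = np.random.rand(4, 4, 8)
bias = decomposed_rel_pos(q, rel_h, rel_w)  # shape (1, 16, 16)
```

Because the height and width contributions are computed independently and only combined at the end, collapsing this into one helper with a single final reshape (rather than reshaping inside each attention class) is what makes the refactor possible.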

A few questions from my side related to add_decomposed_rel_pos:

  1. Should I update the TFSamVisionAttention class to match the new SamVisionAttention?
  2. Or should I keep the original add_decomposed_rel_pos function, make slight changes so it is also compatible with SamVisionSdpaAttention, and apply those changes to TFSamVisionAttention as well?

Who can review?

@amyeroberts, @qubvel, @zucchini-nlp

qubvel (Member) commented Feb 26, 2025

Hi @geetu040, thanks for opening the PR!

qubvel (Member) commented Feb 26, 2025

run-slow: sam

qubvel added the Vision label on Feb 26, 2025

This comment contains run-slow, running the specified jobs:

models: ['models/sam']
quantizations: [] ...
