Fix sdpa in sam and refactor relative position embeddings #36422
## What does this PR do?
This PR identifies and fixes the following problems with `SamVisionSdpaAttention`:

1. The `attn_weights` returned by `SamVisionAttention` and `SamVisionSdpaAttention` are of different sizes: `[batch_size, num_attention_head, seq_len, seq_len]` and `[batch_size * num_attention_head, seq_len, seq_len]`, respectively.
2. `scaled_dot_product_attention` in `torch` does not return `attn_weights`, so these `attn_weights` are calculated manually in `SamVisionSdpaAttention`, which defeats the purpose of sdpa in the first place. Instead, if `output_attentions==True`, we should fall back to the eager implementation in `SamVisionAttention`, as sketched after this list.
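For illustration, here is a minimal sketch of that dispatch, assuming 4-D `[batch_size, num_heads, seq_len, head_dim]` inputs; `sdpa_attention_forward` is a hypothetical standalone helper, not the actual `SamVisionSdpaAttention.forward`:

```python
import torch
import torch.nn.functional as F

def sdpa_attention_forward(query, key, value, output_attentions=False):
    """Hypothetical sketch: use torch sdpa, fall back to eager when weights are requested."""
    if output_attentions:
        # Eager path: torch's fused sdpa kernel cannot return attention
        # weights, so compute them explicitly, keeping the 4-D
        # [batch_size, num_heads, seq_len, seq_len] shape in both paths.
        scale = query.shape[-1] ** -0.5
        attn_weights = torch.softmax((query * scale) @ key.transpose(-2, -1), dim=-1)
        return attn_weights @ value, attn_weights
    # Fast path: let the fused kernel do the work; no weights are materialized.
    return F.scaled_dot_product_attention(query, key, value), None
```

With this dispatch the manual re-computation of `attn_weights` disappears from the sdpa class, and both implementations report weights of the same shape.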
`SamVisionAttention.get_decomposed_rel_pos`

I had to change the `add_decomposed_rel_pos` function, since when `output_attentions==True` we fall back to "eager", in which case the overloaded method is not called and the parent's method is not compatible; a sketch of the resulting helper is shown below.
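To illustrate the refactor, here is a minimal sketch of the decomposed relative position bias computed as a standalone tensor. It assumes `rel_h`/`rel_w` are the per-axis relative position embeddings already sliced to the query/key window (the interpolation step of the real helper is omitted); the function name mirrors the PR, but the body is illustrative only:

```python
import torch

def get_decomposed_rel_pos(query, rel_h, rel_w, q_size, k_size):
    # query: [batch, q_h * q_w, dim], rel_h: [q_h, k_h, dim], rel_w: [q_w, k_w, dim]
    q_h, q_w = q_size
    k_h, k_w = k_size
    batch, _, dim = query.shape
    r_query = query.reshape(batch, q_h, q_w, dim)
    rel_h_bias = torch.einsum("bhwc,hkc->bhwk", r_query, rel_h)  # [b, q_h, q_w, k_h]
    rel_w_bias = torch.einsum("bhwc,wkc->bhwk", r_query, rel_w)  # [b, q_h, q_w, k_w]
    # Return the bias instead of adding it to attention scores in place:
    # the eager path can add it to attn_weights before the softmax, and the
    # sdpa path can pass the same tensor to scaled_dot_product_attention
    # as attn_mask, so a single helper serves both implementations.
    bias = rel_h_bias[:, :, :, :, None] + rel_w_bias[:, :, :, None, :]
    return bias.reshape(batch, q_h * q_w, k_h * k_w)
```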
This change affected some other classes as well:

- `GotOcr2VisionAttention` completely inherits from `SamVisionAttention`; since the model is modular, just running `utils/modular_model_converter.py` should fix it (see the command after this list).
- `TFSamVisionAttention` mimics the `SamVisionAttention` class; the two now look slightly different.
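For reference, the regeneration step mentioned above would look something like the following; the `--files_to_parse` flag is taken from the modular-transformers docs, while the exact modular file path for GOT-OCR2 is my assumption:

```
python utils/modular_model_converter.py --files_to_parse src/transformers/models/got_ocr2/modular_got_ocr2.py
```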
A few questions from my side, related to `add_decomposed_rel_pos`:

- Should I change the `TFSamVisionAttention` class to also look like `SamVisionAttention`?
- Or should I keep `add_decomposed_rel_pos`, make slight changes so it is compatible with `SamVisionSdpaAttention` as well, and apply these changes to `TFSamVisionAttention` too?

## Who can review?
@amyeroberts, @qubvel, @zucchini-nlp