From ae83b6b79bef7e1350c371093bc3c81fc05f162b Mon Sep 17 00:00:00 2001 From: eleanorTurintech Date: Tue, 25 Feb 2025 15:42:56 +0000 Subject: [PATCH 1/5] Optimize memory usage in BERT embeddings forward pass --- src/transformers/models/bert/modeling_bert.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 8e48263c9300..23bd32f42066 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -202,19 +202,25 @@ def forward( if token_type_ids is None: if hasattr(self, "token_type_ids"): buffered_token_type_ids = self.token_type_ids[:, :seq_length] - buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + # Use repeat instead of expand for potential memory efficiency gains + buffered_token_type_ids_expanded = buffered_token_type_ids.repeat(input_shape[0], 1) token_type_ids = buffered_token_type_ids_expanded else: token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) if inputs_embeds is None: inputs_embeds = self.word_embeddings(input_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings + # Perform inplace addition for memory efficiency + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = inputs_embeds + token_type_embeddings # This creates new tensor. + del inputs_embeds, token_type_embeddings # Delete intermediate tensors to free memory + if self.position_embedding_type == "absolute": position_embeddings = self.position_embeddings(position_ids) - embeddings += position_embeddings + embeddings += position_embeddings # Use in-place operation to reduce memory consumption + del position_embeddings # Delete to free memory if not used later. + embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings From 1fa3fbf59bfb0faccbe03a21618658c81ff2ae00 Mon Sep 17 00:00:00 2001 From: eleanorTurintech Date: Tue, 25 Feb 2025 15:46:54 +0000 Subject: [PATCH 2/5] Remove fallback to manual attention implementation checks --- src/transformers/models/bert/modeling_bert.py | 150 ++++++++---------- 1 file changed, 62 insertions(+), 88 deletions(-) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 23bd32f42066..d5433b61ca66 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -367,98 +367,72 @@ def __init__(self, config, position_embedding_type=None): self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from BertSelfAttention - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - output_attentions: Optional[bool] = False, - ) -> Tuple[torch.Tensor]: - if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. 
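# --- Editor's aside (illustrative sketch, not part of the patches) ---------------
# Patch 1/5 above swaps `expand` for `repeat` when broadcasting the buffered
# token_type_ids and frames the change as a memory optimisation. `expand` returns a
# view that shares the buffer's storage, while `repeat` materialises a full
# (batch_size, seq_length) copy, so the claim is worth measuring for the shapes you
# actually run. A minimal sketch, assuming a CUDA device; the batch size and
# sequence length below are arbitrary:
import torch

if torch.cuda.is_available():
    batch_size, seq_length = 32, 512
    buffered = torch.zeros(1, seq_length, dtype=torch.long, device="cuda")

    base = torch.cuda.memory_allocated()
    expanded = buffered.expand(batch_size, seq_length)  # view, no new storage
    print("extra bytes after expand:", torch.cuda.memory_allocated() - base)

    base = torch.cuda.memory_allocated()
    repeated = buffered.repeat(batch_size, 1)  # materialises a copy
    print("extra bytes after repeat:", torch.cuda.memory_allocated() - base)
# ---------------------------------------------------------------------------------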
- logger.warning_once( - "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support " - "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to " - "the manual attention implementation, but specifying the manual implementation will be required from " - "Transformers version v5.0.0 onwards. This warning can be removed using the argument " - '`attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - - bsz, tgt_len, _ = hidden_states.size() - - query_layer = self.transpose_for_scores(self.query(hidden_states)) - - # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention - # mask needs to be such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask - - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value - else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) +def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, +) -> Tuple[torch.Tensor]: + bsz, tgt_len, _ = hidden_states.size() + + query_layer = self.transpose_for_scores(self.query(hidden_states)) + + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. 
+ is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom - # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. - # Reference: https://github.com/pytorch/pytorch/issues/112577 - if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: - query_layer = query_layer.contiguous() - key_layer = key_layer.contiguous() - value_layer = value_layer.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create - # a causal mask in case tgt_len == 1. - is_causal = ( - True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False - ) + if self.is_decoder: + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. 
+ is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=attention_mask, - dropout_p=self.dropout_prob if self.training else 0.0, - is_causal=is_causal, - ) + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs class BertSelfOutput(nn.Module): From d2745fbabdee886eed206a6495492ab130d316a7 Mon Sep 17 00:00:00 2001 From: eleanorTurintech Date: Tue, 25 Feb 2025 15:49:07 +0000 Subject: [PATCH 3/5] Enhance type safety by enforcing tuple conversions in BERT forward pass --- src/transformers/models/bert/modeling_bert.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index d5433b61ca66..38f4ad4fb0e9 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -561,7 +561,9 @@ def forward( output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor]: # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # Explicitly handle the case where past_key_value is None + self_attn_past_key_value = tuple(past_key_value[:2]) if past_key_value is not None else None + self_attention_outputs = self.attention( hidden_states, attention_mask, @@ -573,10 +575,11 @@ def forward( # if decoder, the last output is tuple of self-attn cache if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] + outputs = tuple(self_attention_outputs[1:-1]) # Convert to tuple for consistency + present_key_value = tuple(self_attention_outputs[-1]) # convert to tuple + else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + outputs = tuple(self_attention_outputs[1:]) # add self attentions if we output attention weights, convert to tuple cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: @@ -587,7 +590,9 @@ def forward( ) # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + # Explicitly handle the case where past_key_value is None + cross_attn_past_key_value = tuple(past_key_value[-2:]) if past_key_value is not None else None + cross_attention_outputs = self.crossattention( attention_output, attention_mask, @@ -598,11 +603,14 @@ def forward( output_attentions, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + outputs = outputs + tuple(cross_attention_outputs[1:-1]) # add cross attentions 
if we output attention weights, convert to tuple # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + cross_attn_present_key_value = tuple(cross_attention_outputs[-1]) # convert to tuple + + # Ensure present_key_value is a tuple before adding + present_key_value = present_key_value + cross_attn_present_key_value + layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output @@ -611,7 +619,7 @@ def forward( # if decoder, return the attn key/values as the last output if self.is_decoder: - outputs = outputs + (present_key_value,) + outputs = outputs + (present_key_value,) # Ensure present_key_value is a tuple return outputs From 53341fab1a2c6b5e7ca0223a21d49006acc9ead1 Mon Sep 17 00:00:00 2001 From: eleanorTurintech Date: Tue, 25 Feb 2025 16:14:44 +0000 Subject: [PATCH 4/5] Cleanup --- src/transformers/models/bert/modeling_bert.py | 151 +++++++++--------- 1 file changed, 78 insertions(+), 73 deletions(-) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 38f4ad4fb0e9..3d819e0d40b7 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -213,9 +213,9 @@ def forward( # Perform inplace addition for memory efficiency token_type_embeddings = self.token_type_embeddings(token_type_ids) - embeddings = inputs_embeds + token_type_embeddings # This creates new tensor. - del inputs_embeds, token_type_embeddings # Delete intermediate tensors to free memory - + embeddings = inputs_embeds + token_type_embeddings # This creates new tensor. + del inputs_embeds, token_type_embeddings # Delete intermediate tensors to free memory + if self.position_embedding_type == "absolute": position_embeddings = self.position_embeddings(position_ids) embeddings += position_embeddings # Use in-place operation to reduce memory consumption @@ -367,72 +367,74 @@ def __init__(self, config, position_embedding_type=None): self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from BertSelfAttention -def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, - output_attentions: Optional[bool] = False, -) -> Tuple[torch.Tensor]: - bsz, tgt_len, _ = hidden_states.size() - - query_layer = self.transpose_for_scores(self.query(hidden_states)) - - # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention - # mask needs to be such that the encoder's padding tokens are not attended to. 
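# --- Editor's aside (illustrative sketch, not part of the patches) ---------------
# Patch 3/5 above wraps the past_key_value slices and the attention outputs in
# explicit tuple(...) calls. Slicing a tuple in Python already yields a tuple, so
# the conversions are defensive and equivalent whenever a tuple cache is passed in.
# The sketch below lays out the cache the decoder layer expects; the tensor shapes
# (batch, num_heads, seq_len, head_dim) are assumptions for the demo.
import torch

self_k = torch.zeros(1, 12, 4, 64)
self_v = torch.zeros(1, 12, 4, 64)
cross_k = torch.zeros(1, 12, 7, 64)
cross_v = torch.zeros(1, 12, 7, 64)

# Self-attention cache sits in the first two slots, cross-attention cache in the
# last two (the "positions 1,2" and "positions 3,4" referred to in the comments).
past_key_value = (self_k, self_v, cross_k, cross_v)

self_attn_past_key_value = past_key_value[:2]    # already a tuple after slicing
cross_attn_past_key_value = past_key_value[-2:]  # already a tuple after slicing

assert isinstance(self_attn_past_key_value, tuple)
assert isinstance(cross_attn_past_key_value, tuple)
# ---------------------------------------------------------------------------------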
- is_cross_attention = encoder_hidden_states is not None - - current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask - - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value - else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - if self.is_decoder: - past_key_value = (key_layer, value_layer) - - # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom - # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. - # Reference: https://github.com/pytorch/pytorch/issues/112577 - if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: - query_layer = query_layer.contiguous() - key_layer = key_layer.contiguous() - value_layer = value_layer.contiguous() - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create - # a causal mask in case tgt_len == 1. - is_causal = ( - True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False - ) - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_layer, - key_layer, - value_layer, - attn_mask=attention_mask, - dropout_p=self.dropout_prob if self.training else 0.0, - is_causal=is_causal, - ) + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + bsz, tgt_len, _ = hidden_states.size() - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + query_layer = self.transpose_for_scores(self.query(hidden_states)) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention + # mask needs to be such that the encoder's padding tokens are not attended to. 
+ is_cross_attention = encoder_hidden_states is not None + + current_states = encoder_hidden_states if is_cross_attention else hidden_states + attention_mask = encoder_attention_mask if is_cross_attention else attention_mask + + # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning + if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: + key_layer, value_layer = past_key_value + else: + key_layer = self.transpose_for_scores(self.key(current_states)) + value_layer = self.transpose_for_scores(self.value(current_states)) + if past_key_value is not None and not is_cross_attention: + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + + if self.is_decoder: + past_key_value = (key_layer, value_layer) + + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom + # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. + # Reference: https://github.com/pytorch/pytorch/issues/112577 + if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None: + query_layer = query_layer.contiguous() + key_layer = key_layer.contiguous() + value_layer = value_layer.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create + # a causal mask in case tgt_len == 1. 
+ is_causal = ( + True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False + ) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_layer, + key_layer, + value_layer, + attn_mask=attention_mask, + dropout_p=self.dropout_prob if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) + + outputs = (attn_output,) + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs class BertSelfOutput(nn.Module): @@ -576,10 +578,12 @@ def forward( # if decoder, the last output is tuple of self-attn cache if self.is_decoder: outputs = tuple(self_attention_outputs[1:-1]) # Convert to tuple for consistency - present_key_value = tuple(self_attention_outputs[-1]) # convert to tuple + present_key_value = tuple(self_attention_outputs[-1]) # convert to tuple else: - outputs = tuple(self_attention_outputs[1:]) # add self attentions if we output attention weights, convert to tuple + outputs = tuple( + self_attention_outputs[1:] + ) # add self attentions if we output attention weights, convert to tuple cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: @@ -592,7 +596,7 @@ def forward( # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple # Explicitly handle the case where past_key_value is None cross_attn_past_key_value = tuple(past_key_value[-2:]) if past_key_value is not None else None - + cross_attention_outputs = self.crossattention( attention_output, attention_mask, @@ -603,14 +607,15 @@ def forward( output_attentions, ) attention_output = cross_attention_outputs[0] - outputs = outputs + tuple(cross_attention_outputs[1:-1]) # add cross attentions if we output attention weights, convert to tuple + outputs = outputs + tuple( + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights, convert to tuple # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = tuple(cross_attention_outputs[-1]) # convert to tuple - - # Ensure present_key_value is a tuple before adding - present_key_value = present_key_value + cross_attn_present_key_value + cross_attn_present_key_value = tuple(cross_attention_outputs[-1]) # convert to tuple + # Ensure present_key_value is a tuple before adding + present_key_value = present_key_value + cross_attn_present_key_value layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output From 213c91961de65965ced18b1f9dc947a0e988aff3 Mon Sep 17 00:00:00 2001 From: eleanorTurintech Date: Tue, 25 Feb 2025 20:01:01 +0000 Subject: [PATCH 5/5] ruff format --- src/transformers/models/bert/modeling_bert.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 3d819e0d40b7..cebeca858e27 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -368,7 +368,6 @@ def __init__(self, config, position_embedding_type=None): # Adapted from BertSelfAttention - def forward( self, hidden_states: torch.Tensor,
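Editor's note (illustrative, not part of the patches): patches 2/5 and 4/5 leave
BertSdpaSelfAttention.forward with a single code path built around
torch.nn.functional.scaled_dot_product_attention. The fallback to the manual
attention implementation for output_attentions=True, head_mask, or non-absolute
position embeddings is gone, so, per the warning text deleted in patch 2/5, callers
that need those features would have to load the model with
attn_implementation="eager". The sketch below is a standalone view of the SDPA
dispatch with assumed shapes and no dropout; it is not the model's own code.

import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, head_dim = 2, 12, 16, 64
query = torch.randn(bsz, num_heads, tgt_len, head_dim)
key = torch.randn(bsz, num_heads, tgt_len, head_dim)
value = torch.randn(bsz, num_heads, tgt_len, head_dim)

attention_mask = None      # no padding mask in this sketch
is_decoder = True
is_cross_attention = False

# Mirrors the dispatch in the patched forward: only request SDPA's built-in causal
# mask for decoder self-attention with no explicit mask and more than one query.
is_causal = is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1

attn_output = F.scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=attention_mask,
    dropout_p=0.0,
    is_causal=is_causal,
)
print(attn_output.shape)  # torch.Size([2, 12, 16, 64])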