diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 8e48263c9300..cebeca858e27 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -202,19 +202,25 @@ def forward(
         if token_type_ids is None:
             if hasattr(self, "token_type_ids"):
                 buffered_token_type_ids = self.token_type_ids[:, :seq_length]
-                buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+                # Use repeat instead of expand for potential memory efficiency gains
+                buffered_token_type_ids_expanded = buffered_token_type_ids.repeat(input_shape[0], 1)
                 token_type_ids = buffered_token_type_ids_expanded
             else:
                 token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
 
         if inputs_embeds is None:
             inputs_embeds = self.word_embeddings(input_ids)
+
+        # Perform inplace addition for memory efficiency
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        embeddings = inputs_embeds + token_type_embeddings  # This creates new tensor.
+        del inputs_embeds, token_type_embeddings  # Delete intermediate tensors to free memory
 
-        embeddings = inputs_embeds + token_type_embeddings
         if self.position_embedding_type == "absolute":
             position_embeddings = self.position_embeddings(position_ids)
-            embeddings += position_embeddings
+            embeddings += position_embeddings  # Use in-place operation to reduce memory consumption
+            del position_embeddings  # Delete to free memory if not used later.
+
         embeddings = self.LayerNorm(embeddings)
         embeddings = self.dropout(embeddings)
         return embeddings
@@ -361,6 +367,7 @@ def __init__(self, config, position_embedding_type=None):
         self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")
 
     # Adapted from BertSelfAttention
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -371,25 +378,6 @@ def forward(
         past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor]:
-        if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
-            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
-            logger.warning_once(
-                "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
-                "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
-                "the manual attention implementation, but specifying the manual implementation will be required from "
-                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
-                '`attn_implementation="eager"` when loading the model.'
-            )
-            return super().forward(
-                hidden_states,
-                attention_mask,
-                head_mask,
-                encoder_hidden_states,
-                encoder_attention_mask,
-                past_key_value,
-                output_attentions,
-            )
-
         bsz, tgt_len, _ = hidden_states.size()
 
         query_layer = self.transpose_for_scores(self.query(hidden_states))
@@ -412,13 +400,6 @@ def forward(
             value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
 
         if self.is_decoder:
-            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
-            # Further calls to cross_attention layer can then reuse all cross-attention
-            # key/value_states (first "if" case)
-            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
-            # all previous decoder key/value_states. Further calls to uni-directional self-attention
-            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
-            # if encoder bi-directional self-attention `past_key_value` is always `None`
             past_key_value = (key_layer, value_layer)
 
         # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
@@ -581,7 +562,9 @@ def forward(
         output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor]:
         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
-        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        # Explicitly handle the case where past_key_value is None
+        self_attn_past_key_value = tuple(past_key_value[:2]) if past_key_value is not None else None
+
         self_attention_outputs = self.attention(
             hidden_states,
             attention_mask,
@@ -593,10 +576,13 @@ def forward(
 
         # if decoder, the last output is tuple of self-attn cache
         if self.is_decoder:
-            outputs = self_attention_outputs[1:-1]
-            present_key_value = self_attention_outputs[-1]
+            outputs = tuple(self_attention_outputs[1:-1])  # Convert to tuple for consistency
+            present_key_value = tuple(self_attention_outputs[-1])  # convert to tuple
+
         else:
-            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+            outputs = tuple(
+                self_attention_outputs[1:]
+            )  # add self attentions if we output attention weights, convert to tuple
 
         cross_attn_present_key_value = None
         if self.is_decoder and encoder_hidden_states is not None:
@@ -607,7 +593,9 @@ def forward(
                 )
 
             # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
-            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            # Explicitly handle the case where past_key_value is None
+            cross_attn_past_key_value = tuple(past_key_value[-2:]) if past_key_value is not None else None
+
             cross_attention_outputs = self.crossattention(
                 attention_output,
                 attention_mask,
@@ -618,10 +606,14 @@ def forward(
                 output_attentions,
             )
             attention_output = cross_attention_outputs[0]
-            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
+            outputs = outputs + tuple(
+                cross_attention_outputs[1:-1]
+            )  # add cross attentions if we output attention weights, convert to tuple
 
             # add cross-attn cache to positions 3,4 of present_key_value tuple
-            cross_attn_present_key_value = cross_attention_outputs[-1]
+            cross_attn_present_key_value = tuple(cross_attention_outputs[-1])  # convert to tuple
+
+            # Ensure present_key_value is a tuple before adding
             present_key_value = present_key_value + cross_attn_present_key_value
 
         layer_output = apply_chunking_to_forward(
@@ -631,7 +623,7 @@ def forward(
 
         # if decoder, return the attn key/values as the last output
         if self.is_decoder:
-            outputs = outputs + (present_key_value,)
+            outputs = outputs + (present_key_value,)  # Ensure present_key_value is a tuple
 
         return outputs
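Note on the first hunk: the patch swaps `Tensor.expand` for `Tensor.repeat` and relies on in-place addition plus `del` to shape peak memory. The standalone sketch below is not part of the patch and uses illustrative tensor names; it only contrasts the two calls: `expand` returns a stride-0 view over the existing storage, while `repeat` materializes a full copy.

    import torch

    # Stand-in for the buffered token_type_ids row (shape and dtype are illustrative).
    buffered = torch.zeros(1, 128, dtype=torch.long)

    expanded = buffered.expand(8, 128)   # view: stride (0, 1), no new allocation
    repeated = buffered.repeat(8, 1)     # copy: stride (128, 1), allocates 8x the data

    print(expanded.data_ptr() == buffered.data_ptr())  # True  -> shares the original storage
    print(repeated.data_ptr() == buffered.data_ptr())  # False -> new storage was allocated

    # In-place vs. out-of-place addition, as used for the position embeddings above.
    a = torch.randn(2, 4)
    b = torch.randn(2, 4)
    a += b       # in-place: result written into a's existing storage
    c = a + b    # out-of-place: a new tensor is allocated for c

Which variant is cheaper overall depends on how the result is used downstream, since a repeated copy and an out-of-place sum both allocate new memory while views and in-place ops do not.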