[Model] Optimize BERT memory usage and improve code readability #36401

Open · wants to merge 6 commits into base: main
66 changes: 29 additions & 37 deletions src/transformers/models/bert/modeling_bert.py
@@ -202,19 +202,25 @@ def forward(
if token_type_ids is None:
if hasattr(self, "token_type_ids"):
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
- buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ # Use repeat instead of expand for potential memory efficiency gains
+ buffered_token_type_ids_expanded = buffered_token_type_ids.repeat(input_shape[0], 1)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)

+ # Perform inplace addition for memory efficiency
token_type_embeddings = self.token_type_embeddings(token_type_ids)
- embeddings = inputs_embeds + token_type_embeddings
+ embeddings = inputs_embeds + token_type_embeddings # This creates a new tensor.
+ del inputs_embeds, token_type_embeddings # Delete intermediate tensors to free memory

if self.position_embedding_type == "absolute":
position_embeddings = self.position_embeddings(position_ids)
- embeddings += position_embeddings
+ embeddings += position_embeddings # Use in-place operation to reduce memory consumption
+ del position_embeddings # Delete to free memory if not used later.

embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
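Reviewer note: the two techniques in this hunk rely on standard PyTorch semantics. Tensor.expand returns a view over the existing storage, Tensor.repeat materializes a copy, and an in-place += writes into the left-hand tensor's storage instead of allocating a new result. Below is a minimal standalone sketch of those semantics; the shapes are illustrative assumptions, not values taken from the model.

import torch

# Illustrative shapes (assumed): batch of 8, sequence length 512, hidden size 768.
buffered = torch.zeros(1, 512, dtype=torch.long)  # stand-in for the buffered token_type_ids

expanded = buffered.expand(8, 512)  # view: shares storage with `buffered`, no new allocation
repeated = buffered.repeat(8, 1)    # copy: materializes a new (8, 512) tensor

print(expanded.data_ptr() == buffered.data_ptr())  # True: same underlying storage
print(repeated.data_ptr() == buffered.data_ptr())  # False: freshly allocated storage

a = torch.randn(8, 512, 768)
b = torch.randn(8, 512, 768)
c = a + b  # out-of-place: allocates a third (8, 512, 768) tensor
a += b     # in-place: reuses a's existing storage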
@@ -361,6 +367,7 @@ def __init__(self, config, position_embedding_type=None):
self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

# Adapted from BertSelfAttention

def forward(
self,
hidden_states: torch.Tensor,
@@ -371,25 +378,6 @@ def forward(
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
- if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
- logger.warning_once(
- "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
- "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
- "the manual attention implementation, but specifying the manual implementation will be required from "
- "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
- '`attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states,
- attention_mask,
- head_mask,
- encoder_hidden_states,
- encoder_attention_mask,
- past_key_value,
- output_attentions,
- )

bsz, tgt_len, _ = hidden_states.size()

query_layer = self.transpose_for_scores(self.query(hidden_states))
@@ -412,13 +400,6 @@ def forward(
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_layer, value_layer)

# SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom
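The truncated comment above refers to the contiguity workaround guarded by require_contiguous_qkv (set in the __init__ hunk earlier). Below is a minimal sketch, not the PR's code, of how such a workaround is typically applied before calling scaled_dot_product_attention; the helper name is hypothetical and query/key/value are assumed to be shaped (batch, num_heads, seq_len, head_dim).

import torch
import torch.nn.functional as F

def sdpa_with_contiguity_fix(query, key, value, attn_mask=None, require_contiguous_qkv=True):
    # Hypothetical helper: torch==2.1.2's memory-efficient SDPA backend mishandles
    # non-contiguous inputs combined with a custom attn_mask, so force contiguity first.
    if require_contiguous_qkv and query.device.type == "cuda" and attn_mask is not None:
        query, key, value = query.contiguous(), key.contiguous(), value.contiguous()
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)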
@@ -581,7 +562,9 @@ def forward(
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
- self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ # Explicitly handle the case where past_key_value is None
+ self_attn_past_key_value = tuple(past_key_value[:2]) if past_key_value is not None else None

self_attention_outputs = self.attention(
hidden_states,
attention_mask,
@@ -593,10 +576,13 @@

# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
- outputs = self_attention_outputs[1:-1]
- present_key_value = self_attention_outputs[-1]
+ outputs = tuple(self_attention_outputs[1:-1]) # Convert to tuple for consistency
+ present_key_value = tuple(self_attention_outputs[-1]) # convert to tuple

else:
- outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+ outputs = tuple(
+ self_attention_outputs[1:]
+ ) # add self attentions if we output attention weights, convert to tuple

cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
@@ -607,7 +593,9 @@
)

# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
- cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ # Explicitly handle the case where past_key_value is None
+ cross_attn_past_key_value = tuple(past_key_value[-2:]) if past_key_value is not None else None

cross_attention_outputs = self.crossattention(
attention_output,
attention_mask,
@@ -618,10 +606,14 @@
output_attentions,
)
attention_output = cross_attention_outputs[0]
- outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+ outputs = outputs + tuple(
+ cross_attention_outputs[1:-1]
+ ) # add cross attentions if we output attention weights, convert to tuple

# add cross-attn cache to positions 3,4 of present_key_value tuple
- cross_attn_present_key_value = cross_attention_outputs[-1]
+ cross_attn_present_key_value = tuple(cross_attention_outputs[-1]) # convert to tuple

+ # Ensure present_key_value is a tuple before adding
present_key_value = present_key_value + cross_attn_present_key_value

layer_output = apply_chunking_to_forward(
@@ -631,7 +623,7 @@

# if decoder, return the attn key/values as the last output
if self.is_decoder:
- outputs = outputs + (present_key_value,)
+ outputs = outputs + (present_key_value,) # Ensure present_key_value is a tuple

return outputs
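For readers following the cache bookkeeping in this layer, a small illustrative sketch of how the four-element past_key_value tuple referenced in the comments above is sliced and rebuilt; the shapes and values are assumptions, not taken from the model.

import torch

batch, heads, seq_len, head_dim = 2, 12, 5, 64

def kv():
    return torch.randn(batch, heads, seq_len, head_dim)

# Per-layer cache layout: (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)
past_key_value = (kv(), kv(), kv(), kv())

self_attn_past_key_value = past_key_value[:2]    # positions 1,2: self-attention key/value
cross_attn_past_key_value = past_key_value[-2:]  # positions 3,4: cross-attention key/value

# The layer rebuilds the same layout as its "present" cache for the next step:
present_key_value = self_attn_past_key_value + cross_attn_past_key_value
assert len(present_key_value) == 4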
