concurrency without model cloning #573

Open
wants to merge 17 commits into base: main
Changes from 10 commits
1 change: 1 addition & 0 deletions optimum/intel/openvino/modeling_base.py
@@ -90,6 +90,7 @@ def __init__(

self.model = model
self.request = None
self.compiled_model = None
Collaborator:

not sure why we need a new attribute here

Author:

It is needed to create a new infer_request in the context of the generate method for each concurrent thread. So far the model class had a request attribute pointing to a static infer_request, which cannot be used to allocate new requests. The current setup is a bit confusing: the request attribute is first set to the compiled_model object in the base class, but later it is overwritten to become the infer_request. Eventually the recommendation would be to switch to using the compiled_model attribute and create infer_requests dynamically; it was proposed to make this switch in a separate PR.
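To illustrate the pattern described above, here is a minimal sketch (not code from this PR) of sharing one compiled model across threads while giving each thread its own infer request; the model path, device, and input name are placeholders:

import threading

import numpy as np
import openvino as ov

core = ov.Core()
# one compiled model shared by all threads (placeholder path and device)
compiled_model = core.compile_model("model.xml", "CPU")

def run_inference(inputs):
    # each thread creates its own InferRequest from the shared compiled model
    # instead of reusing a single static request stored on the model object
    infer_request = compiled_model.create_infer_request()
    infer_request.start_async(inputs, share_inputs=True)
    infer_request.wait()
    return infer_request.get_output_tensor(0).data

threads = [
    threading.Thread(target=run_inference, args=({"input_ids": np.zeros((1, 8), dtype=np.int64)},))
    for _ in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()

In the diff below, forward() follows the same idea: it creates an infer request from self.compiled_model when none is passed in and returns it in OVCausalLMOutputWithPast, so that generate() can carry it across iterations through _update_model_kwargs_for_generation.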

if enable_compilation:
self.compile()

94 changes: 69 additions & 25 deletions optimum/intel/openvino/modeling_decoder.py
@@ -17,7 +17,7 @@
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Dict, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import openvino
@@ -28,7 +28,8 @@
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
echarlaix marked this conversation as resolved.

from transformers.utils import ModelOutput
from dataclasses import dataclass
from optimum.utils.normalized_config import NormalizedConfigManager

from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful
@@ -44,6 +45,23 @@

core = Core()

@dataclass
class OVCausalLMOutputWithPast(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.

Args:
infer_request (`openvino.runtime.InferRequest`): inference request to be reused across the generation cycles.
beam_idx (`torch.Tensor`): beam search algorithm context for generation with stateful models.
"""

loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
infer_request: Optional[openvino.runtime.InferRequest] = None
Collaborator:

could we rename it to something like request or inference_request?

Collaborator:

The current name is aligned with the openvino API name, so for me infer_request sounds better.

Collaborator:

I think that'd be clearer for users who are not familiar with the openvino ecosystem. Also, we don't use infer_request anywhere in optimum-intel, so I was thinking about something a bit more explicit.

beam_idx: Optional[torch.Tensor] = None

TEXT_GENERATION_EXAMPLE = r"""
Example of text generation:
@@ -119,7 +137,6 @@ def __init__(
self.key_value_output_names = [key for key in self.output_names if "present" in key]
self._original_model = self.model.clone() # keep original model for serialization
self._pkv_precision = Type.f32
self.next_beam_idx = None
self.update_pkv_precision()
if self.is_dynamic:
self.model = self._reshape(self.model, -1, -1)
@@ -197,6 +214,7 @@ def update_pkv_precision(self, force_fp32=False):
if self.is_dynamic:
self.model = self._reshape(self.model, -1, -1)
self.request = None
self.compiled_model = None

def _save_pretrained(self, save_directory: Union[str, Path]):
"""
@@ -322,6 +340,7 @@ def normalized_config(self):
def compile(self):
if self.request is None:
super().compile()
self.compiled_model = self.request
Collaborator:

It could make sense to also set self.compiled_model to None (along with self.request) when the model is statically reshaped or moved to another device: https://github.com/huggingface/optimum-intel/blob/2a397e37dd606cdeafce6b356f5e7f869630ea1b/optimum/intel/openvino/modeling_base.py#L442C9-L442C21
An option could be to add a clear_requests method, as done for the seq2seq models.
Currently it should work anyway, as self.compiled_model will be correctly updated after calling .compile() (since self.request is set to None after each of these steps).
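A possible shape for the suggested helper, mirroring the seq2seq models (a sketch with assumed naming, not code from this PR):

def clear_requests(self):
    # drop both handles so the next compile() call rebuilds them consistently
    self.request = None
    self.compiled_model = None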

self.request = self.request.create_infer_request()
Collaborator:

why not remove this:

Suggested change
self.request = self.request.create_infer_request()

and use self.request instead of self.compiled_model? (self.request doesn't seem to be used anywhere)
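For reference, a rough sketch of the alternative suggested above (an interpretation, not code from this PR), where compile() keeps self.request pointing at the compiled model and infer requests are created on demand:

def compile(self):
    if self.request is None:
        super().compile()  # base class sets self.request to the ov.CompiledModel
    # no static infer request is created here; callers such as forward() would do
    # infer_request = self.request.create_infer_request() when they need one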


def _make_stateful(self):
@@ -340,6 +359,13 @@ class OVModelForCausalLM(OVBaseDecoderModel, GenerationMixin):
export_feature = "text-generation"
auto_model_class = AutoModelForCausalLM

# def generate(self, *args, **kwargs):
# self.compile()
# if kwargs.get("infer_request") is None:
# infer_context = [self.compiled_model.create_infer_request()]
# kwargs["infer_context"] = infer_context
# return super().generate(*args, **kwargs)

@add_start_docstrings_to_model_forward(
INPUTS_DOCSTRING.format("batch_size, sequence_length")
+ TEXT_GENERATION_EXAMPLE.format(
@@ -354,6 +380,7 @@ def prepare_inputs(
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
position_ids: Optional[torch.LongTensor] = None,
beam_idx: Optional[torch.tensor] = None,
**kwargs,
) -> Dict:
if self.use_cache and past_key_values is not None:
@@ -362,7 +389,6 @@
batch_size = input_ids.shape[0]
if self.config.model_type == "bloom":
batch_size *= self.config.num_attention_heads

inputs = {}
past_len = 0
if not self.stateful:
@@ -402,15 +428,6 @@
else:
shape[1] = 0
inputs[input_name] = Tensor(model_inputs.get_element_type(), shape.get_shape())
else:
# past_key_values are not used explicitly, instead they are handled inside the model
if past_key_values is None:
# This is the first iteration in a sequence, reset all states
if self.request is not None:
self.request.reset_state()
# Set initial value for the next beam_idx input that will be used at the current iteration
# and will be optionally updated by _reorder_cache at the next iterations if beam_search is used
self.next_beam_idx = np.arange(batch_size, dtype=int)

inputs["input_ids"] = np.array(input_ids)
# Add the attention_mask inputs when needed
@@ -438,7 +455,7 @@ def prepare_inputs(

if "beam_idx" in self.input_names:
inputs["beam_idx"] = (
self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int)
beam_idx if beam_idx is not None else np.arange(batch_size, dtype=int)
)

return inputs
@@ -449,22 +466,28 @@ def forward(
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
position_ids: Optional[torch.LongTensor] = None,
infer_request: Optional[openvino.runtime.InferRequest] = None,
beam_idx: torch.Tensor = None,
**kwargs,
) -> CausalLMOutputWithPast:
) -> OVCausalLMOutputWithPast:
self.compile()
echarlaix marked this conversation as resolved.

inputs = self.prepare_inputs(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
position_ids=position_ids,
beam_idx=beam_idx,
**kwargs,
)

# Run inference
self.request.start_async(inputs, share_inputs=True)
self.request.wait()
logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
if infer_request is None:
self.compile()
infer_request = self.compiled_model.create_infer_request()

infer_request.start_async(inputs, share_inputs=True)
infer_request.wait()
logits = torch.from_numpy(infer_request.get_tensor("logits").data).to(self.device)
if self.stateful:
# Need a marker to differentiate the first generate iteration from the others in
# the first condition at the function beginning above.
@@ -474,7 +497,7 @@
if not self.stateful:
if self.use_cache:
# Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
past_key_values = tuple(infer_request.get_tensor(key).data for key in self.key_value_output_names)
if self.config.model_type not in MULTI_QUERY_ATTN_MODELS:
# Tuple of tuple of length `n_layers`, with each tuple of length equal to 2 (k/v of self-attention)
past_key_values = tuple(
@@ -483,14 +506,31 @@
else:
past_key_values = None

return CausalLMOutputWithPast(logits=logits, past_key_values=past_key_values)
return OVCausalLMOutputWithPast(logits=logits, past_key_values=past_key_values, infer_request=infer_request, beam_idx=beam_idx)

def _update_model_kwargs_for_generation(
self, outputs: OVCausalLMOutputWithPast,
model_kwargs: dict[str],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
) -> dict[str]:
model_kwargs = super()._update_model_kwargs_for_generation(
outputs=outputs,
model_kwargs=model_kwargs,
is_encoder_decoder=is_encoder_decoder,
standardize_cache_format=standardize_cache_format,
)
if "infer_request" in outputs: model_kwargs["infer_request"] = outputs["infer_request"]
if "beam_idx" in outputs: model_kwargs["beam_idx"] = outputs["beam_idx"]
return model_kwargs

# Adapted from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.prepare_inputs_for_generation
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
attention_mask = kwargs.get("attention_mask", None)
use_cache = kwargs.get("use_cache", None)

infer_request = kwargs.get("infer_request", None)
beam_idx = kwargs.get("beam_idx", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
@@ -503,6 +543,8 @@
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": use_cache,
"infer_request": infer_request,
"beam_idx": beam_idx,
"position_ids": position_ids,
"attention_mask": attention_mask,
}
@@ -519,7 +561,8 @@ def _reorder_cache(
if self.stateful:
# TODO: Apply it differently based on model type
# TODO: At least for bloom we need to replicate values for each attention head
self.next_beam_idx = np.array(beam_idx) # save beam_idx to be used as an input in the next iteration
# save beam_idx and infer_request to be used as an input in the next iteration

return past_key_values
else:
return tuple(
@@ -661,8 +704,7 @@ def _reorder_cache(
batch_size = beam_idx.shape[0]
indices = np.array(range(batch_size * self.config.num_attention_heads))
indices = indices.reshape([batch_size, self.config.num_attention_heads])
self.next_beam_idx = np.take(indices, beam_idx, 0).flatten()
return past_key_values
return ((np.take(indices, beam_idx, 0).flatten()), past_key_values[1])
Collaborator:

shouldn't it be:

Suggested change
return ((np.take(indices, beam_idx, 0).flatten()), past_key_values[1])
return past_key_values

else:
standardized_past = self._convert_to_standard_cache(past_key_values, batch_size=len(beam_idx))
reordered_past = tuple(
@@ -752,3 +794,5 @@ def _reorder_cache(
return past_key_values
else:
return tuple(np.take(layer_past, beam_idx, 0) for layer_past in past_key_values)

