Iterative generation using Input embeds and static cache #35890

Open · wants to merge 8 commits into base: main
10 changes: 7 additions & 3 deletions src/transformers/generation/utils.py
@@ -381,9 +381,13 @@ def prepare_inputs_for_generation(
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
# (we can't check exception 3 while compiling)
# Exception 4: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens,
# then generate the first token for each sequence. The generated `input_ids` are used for continuation afterwards.
if past_key_values is not None:
model_inputs["past_key_values"] = past_key_values
if (
if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
inputs_embeds is not None # Exception 1
or (is_torchdynamo_compiling() or cache_position[-1] >= input_ids.shape[1]) # Exception 3
):
@@ -393,9 +397,9 @@ def prepare_inputs_for_generation(

# 3. Prepare base model inputs
input_ids_key = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step for every prompt.
if not self.config.is_encoder_decoder:
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs[input_ids_key] = None
model_inputs["inputs_embeds"] = inputs_embeds
else:
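To make the new Exception 4 slicing concrete, here is a small self-contained sketch (illustration only, not part of the diff) with made-up shapes: the caller passed embeddings only, four of the six prompt positions are already in the cache, and `cache_position` covers the two unprocessed positions.

import torch

batch_size, prompt_len, hidden = 1, 6, 8
inputs_embeds = torch.randn(batch_size, prompt_len, hidden)
input_ids = torch.empty(batch_size, 0, dtype=torch.long)  # embeddings-only call, so input_ids is empty
cache_position = torch.tensor([4, 5])  # 4 of the 6 prompt positions are already cached

if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
    inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]

# Only the 2 unprocessed embeddings are kept, and len(cache_position) == inputs_embeds.shape[1]
# now holds, so this step runs on `inputs_embeds` rather than `input_ids`.
print(inputs_embeds.shape)  # torch.Size([1, 2, 8])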
13 changes: 9 additions & 4 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -892,15 +892,20 @@ def prepare_inputs_for_generation(

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here.
# Exception 3: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens,
# then generate the first token for each sequence. The generated `input_ids` are used for continuation afterwards.
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
if inputs_embeds is not None:
if input_ids.shape[1] == 0: # Exception 3
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
else:
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
# This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the
8 changes: 5 additions & 3 deletions src/transformers/models/chameleon/modeling_chameleon.py
@@ -1652,8 +1652,10 @@ def prepare_inputs_for_generation(
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
if inputs_embeds is not None:
if input_ids.shape[1] == 0:
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
input_ids = input_ids[:, -cache_position.shape[0] :] # Exception 1
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]

@@ -1665,7 +1667,7 @@ def prepare_inputs_for_generation(
position_ids = position_ids[:, -input_ids.shape[1] :]

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
9 changes: 6 additions & 3 deletions src/transformers/models/idefics/modeling_idefics.py
@@ -1674,10 +1674,13 @@ def prepare_inputs_for_generation(
else:
model_inputs["pixel_values"] = pixel_values

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# If we have cache: let's slice `input_ids` or `inputs_embeds` through `cache_position`, to keep only the unprocessed tokens
if past_key_values is not None:
if inputs_embeds is not None:
input_ids = input_ids[:, -cache_position.shape[0] :]
if input_ids.shape[1] == 0:
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
else:
input_ids = input_ids[:, -cache_position.shape[0] :]
elif input_ids.shape[1] != cache_position.shape[0]:
input_ids = input_ids[:, cache_position]
if image_attention_mask is not None:
@@ -1694,7 +1697,7 @@ def prepare_inputs_for_generation(
position_ids = position_ids.clone(memory_format=torch.contiguous_format)

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs.update({"inputs_embeds": inputs_embeds, "input_ids": None})
else:
# The clone here is for the same reason as for `position_ids`.
13 changes: 9 additions & 4 deletions src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -1864,10 +1864,15 @@ def prepare_inputs_for_generation(

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here.
# Exception 3: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens,
# then generate the first token for each sequence. The generated `input_ids` are used for continuation afterwards.
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
if inputs_embeds is not None:
if input_ids.shape[1] == 0: # Exception 3
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
else:
input_ids = input_ids[:, -cache_position.shape[0] :] # Exception 1
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]

@@ -1876,7 +1881,7 @@ def prepare_inputs_for_generation(
pixel_values_videos = None

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
13 changes: 9 additions & 4 deletions src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
@@ -763,10 +763,15 @@ def prepare_inputs_for_generation(

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here.
# Exception 3: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens,
# then generate the first token for each sequence. The generated `input_ids` are used for continuation afterwards.
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
if inputs_embeds is not None:
if input_ids.shape[1] == 0: # Exception 3
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
else:
input_ids = input_ids[:, -cache_position.shape[0] :] # Exception 1
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]

@@ -775,7 +780,7 @@ def prepare_inputs_for_generation(
pixel_values_videos = None

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
13 changes: 9 additions & 4 deletions src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1727,10 +1727,15 @@ def prepare_inputs_for_generation(

# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
# Exception 1: when passing input_embeds, input_ids may be missing entries
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
# Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here.
# Exception 3: If `inputs_embeds` are passed, slice them through `cache_position` to keep only the unprocessed tokens,
# then generate the first token for each sequence. The generated `input_ids` are used for continuation afterwards.
if past_key_values is not None:
if inputs_embeds is not None: # Exception 1
input_ids = input_ids[:, -cache_position.shape[0] :]
if inputs_embeds is not None:
if input_ids.shape[1] == 0: # Exception 3
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
else:
input_ids = input_ids[:, -cache_position.shape[0] :] # Exception 1
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]

@@ -1739,7 +1744,7 @@ def prepare_inputs_for_generation(
pixel_values_videos = None

# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and cache_position[0] == 0:
if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
else:
model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
77 changes: 77 additions & 0 deletions tests/generation/test_utils.py
@@ -1857,6 +1857,83 @@ def test_generate_continue_from_past_key_values(self):
)
)

@pytest.mark.generate
def test_generate_continue_from_inputs_embeds(self):
"""Tests that we can continue generation from `inputs_embeds` and past key values returned from a previous `generate` call."""
for model_class in self.all_generative_model_classes:
if any(model_name in model_class.__name__.lower() for model_name in ["imagegpt"]):
self.skipTest(reason="Won't fix: old model with unique inputs/caches/other")
if any(model_name in model_class.__name__.lower() for model_name in ["umt5"]):
self.skipTest(reason="TODO: needs modeling or test input preparation fixes for compatibility")

config, inputs_dict = self.prepare_config_and_inputs_for_generate()

if "token_type_ids" in inputs_dict:
del inputs_dict["token_type_ids"]

if config.is_encoder_decoder:
self.skipTest(reason="This model is encoder-decoder")
if not hasattr(config, "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")

model = model_class(config).to(torch_device).eval()

if "inputs_embeds" not in inspect.signature(model.prepare_inputs_for_generation).parameters.keys():
self.skipTest(reason="This model does not support `inputs_embeds` in generation")

# If "past_key_values" is not returned, skip the test (e.g. RWKV uses a different cache name and format)
outputs = model(**inputs_dict)
if "past_key_values" not in outputs:
self.skipTest(reason="This model doesn't return `past_key_values`")

pixel_values_is_mutually_exclusive = any(
model_name in model_class.__name__.lower()
for model_name in ["llava", "idefics2", "idefics3", "mllama", "paligemma", "emu3"]
)
if pixel_values_is_mutually_exclusive:
inputs_dict.pop("pixel_values", None)
inputs_dict.pop("pixel_values_videos", None)
inputs_dict.pop("pixel_values_images", None)

input_ids = inputs_dict.pop("input_ids")

model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
model.generation_config.forced_eos_token_id = None
model.config.is_decoder = True
model.generation_config.use_cache = True

generation_kwargs = {
"return_dict_in_generate": True,
"do_sample": False,
}

# Traditional way of generating text, with `return_dict_in_generate` to return the past key values.
input_embeds = model.get_input_embeddings()(input_ids)
outputs = model.generate(inputs_embeds=input_embeds, max_new_tokens=4, **generation_kwargs)

# Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens)
initial_output = model.generate(inputs_embeds=input_embeds, max_new_tokens=3, **generation_kwargs)
continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1)
cached_output = model.generate(
inputs_embeds=continued_embeds,
max_new_tokens=1,
past_key_values=initial_output.past_key_values,
**generation_kwargs,
)

# Combine the (3 + 1) generated tokens and verify it matches with full generation.
combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1)
self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist())
# The two sets of past kv should be equal to each other
for layer_idx in range(len(cached_output.past_key_values)):
for kv_idx in range(len(cached_output.past_key_values[layer_idx])):
self.assertTrue(
torch.allclose(
outputs.past_key_values[layer_idx][kv_idx],
cached_output.past_key_values[layer_idx][kv_idx],
)
)

@parameterized.expand([("offloaded",)]) # ("offloaded_static",) TODO: @raushan fixme in some models (eg T5)
@require_torch_gpu
@pytest.mark.generate
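Outside the test harness, the flow exercised by the new test looks roughly like the sketch below. This is an illustration, not part of the PR: the checkpoint name is an arbitrary decoder-only model assumed to support a static cache, and the `StaticCache` keyword arguments may differ slightly between transformers versions.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "Qwen/Qwen2.5-0.5B"  # assumed checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).eval()

prompt_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
prompt_embeds = model.get_input_embeddings()(prompt_ids)

# Pre-allocated static cache reused across both generate() calls (argument names may vary by version).
cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=64, dtype=model.dtype)

first = model.generate(
    inputs_embeds=prompt_embeds, past_key_values=cache,
    max_new_tokens=3, do_sample=False, return_dict_in_generate=True,
)
# With embeddings-only input, `first.sequences` holds just the newly generated ids.
continued_embeds = torch.cat([prompt_embeds, model.get_input_embeddings()(first.sequences)], dim=1)
second = model.generate(
    inputs_embeds=continued_embeds, past_key_values=first.past_key_values,
    max_new_tokens=1, do_sample=False, return_dict_in_generate=True,
)
full_generation = torch.cat([first.sequences, second.sequences], dim=1)  # 3 + 1 generated tokens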
4 changes: 4 additions & 0 deletions tests/models/clvp/test_modeling_clvp.py
@@ -334,6 +334,10 @@ def test_training_gradient_checkpointing(self):
loss = model(**inputs).loss
loss.backward()

@unittest.skip(reason="Clvp `prepare_inputs_for_generation` function doesn't have cache position.")
def test_generate_continue_from_inputs_embeds(self):
pass


class ClvpModelForConditionalGenerationTester:
def __init__(self, parent, is_training=False):
4 changes: 4 additions & 0 deletions tests/models/cohere2/test_modeling_cohere2.py
@@ -131,6 +131,10 @@ def test_generate_with_static_cache(self):
def test_generate_from_inputs_embeds_with_static_cache(self):
pass

@unittest.skip("Cohere2 has HybridCache and doesn't support progressive generation using input embeds.")
def test_generate_continue_from_inputs_embeds(self):
pass

# overwrite because HybridCache has fixed length for key/values
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
4 changes: 4 additions & 0 deletions tests/models/fuyu/test_modeling_fuyu.py
@@ -325,6 +325,10 @@ def test_disk_offload_safetensors(self):
def test_model_parallelism(self):
super().test_model_parallelism()

@unittest.skip(reason="Fuyu `prepare_inputs_for_generation` function doesn't have cache position.")
def test_generate_continue_from_inputs_embeds(self):
pass


@slow
@require_torch_accelerator
4 changes: 4 additions & 0 deletions tests/models/gemma2/test_modeling_gemma2.py
@@ -146,6 +146,10 @@ def test_generate_with_static_cache(self):
def test_generate_from_inputs_embeds_with_static_cache(self):
pass

@unittest.skip("Gemma2 has HybridCache and doesn't support StaticCache. Though it could, it shouldn't support.")
def test_generate_continue_from_inputs_embeds(self):
pass

# overwrite because HybridCache has fixed length for key/values
def _check_attentions_for_generate(
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
4 changes: 4 additions & 0 deletions tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py
@@ -450,6 +450,10 @@ def test_disk_offload(self):
def test_past_key_values_format(self):
pass

@unittest.skip(reason="BigCodeGPT has a non-standard KV cache format and breaks this test.")
def test_generate_continue_from_inputs_embeds(self):
pass

def test_gpt_bigcode_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_gpt_bigcode_model(*config_and_inputs)
8 changes: 8 additions & 0 deletions tests/models/moshi/test_modeling_moshi.py
@@ -358,6 +358,10 @@ def test_disk_offload_bin(self):
def test_disk_offload_safetensors(self):
pass

@unittest.skip(reason="Test becomes too complex with Moshi requiring multiple modalities input.")
def test_generate_continue_from_inputs_embeds(self):
pass

@is_flaky(max_attempts=5, description="flaky on some models.")
def test_save_load(self):
super().test_save_load()
@@ -919,6 +923,10 @@ def test_disk_offload_bin(self):
def test_disk_offload_safetensors(self):
pass

@unittest.skip(reason="Test becomes too complex with Moshi requiring multiple modalities")
def test_generate_continue_from_inputs_embeds(self):
pass

@is_flaky(max_attempts=5, description="flaky on some models.")
def test_save_load(self):
super().test_save_load()
4 changes: 4 additions & 0 deletions tests/models/zamba2/test_modeling_zamba2.py
@@ -333,6 +333,10 @@ def test_past_key_values_format(self):
"""
pass

@unittest.skip(reason="Zamba2 has hybrid cache.")
def test_generate_continue_from_inputs_embeds(self):
pass

@unittest.skip(reason="A large mamba2 would be necessary (and costly) for that")
def test_multi_gpu_data_parallel_forward(self):
pass