Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Iterative generation using Input embeds and static cache #35890

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

yaswanth19
Copy link

@yaswanth19 yaswanth19 commented Jan 25, 2025

What does this PR do?

Fixes #34678
Logic: If cache is present along with inputs_embeds then use inputs_embeds to generate first token for every prompt rather than only for the first token of the cache

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@zucchini-nlp

@yaswanth19
Copy link
Author

yaswanth19 commented Jan 25, 2025

Code which I am using to check the feature branch

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache

model_id = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

model.generation_config.max_new_tokens = 30

prompt_cache = StaticCache(config=model.config, batch_size=1, max_cache_len=1000)

INITIAL_PROMPT = "You are a helpful assistant. "
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt")

inputs_embeds = model.get_input_embeddings()(inputs_initial_prompt.input_ids)
outputs = model.generate(inputs_embeds=inputs_embeds, past_key_values=prompt_cache)

response = tokenizer.batch_decode(outputs)[0]
print(response)

prompts = ["Help me to write a blogpost about travelling.", "Write a short note on AI"]
responses = []
for prompt in prompts:
    new_inputs = tokenizer(prompt, return_tensors="pt")
    new_input_ids = torch.cat([outputs, new_inputs.input_ids], dim=1)

    inputs_embeds = torch.cat([inputs_embeds,model.get_input_embeddings()(new_input_ids)],dim=1) # Necessary to align with cache

    outputs = model.generate(inputs_embeds=inputs_embeds, past_key_values=prompt_cache)
    response = tokenizer.batch_decode(outputs)[0]
    print(response)
    responses.append(response)

Copy link
Member

@zucchini-nlp zucchini-nlp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for the PR! We'd also need to add a test for continue_generate_from_inputs_embeds in https://github.com/huggingface/transformers/blob/main/tests/generation/test_utils.py

LMK if you need any help with adding/running tests :)

Comment on lines +386 to +388
if inputs_embeds is not None and input_ids.shape[1] == 0:
inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
elif (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a small comment, same way as there are 3 exceptions above?

@zucchini-nlp zucchini-nlp requested a review from gante January 27, 2025 08:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Bug when using StaticCache in Qwen2.5 Inference
2 participants