forward() got an unexpected keyword argument 'num_items_in_batch' #35838

Open
Bachstelze opened this issue Jan 22, 2025 · 13 comments · May be fixed by #35875
@Bachstelze

System Info

New transformers versions can't train encoder-decoder models.
Related issue and pull request: #34575
System-Info:

  • transformers version: 4.48.1
  • Platform: Linux-6.8.0-36-generic-x86_64-with-glibc2.39
  • Python version: 3.12.8
  • Huggingface_hub version: 0.24.6
  • Safetensors version: 0.4.5
  • Accelerate version: 1.2.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: Yes
  • GPU type: Tesla V100-PCIE-32GB
Traceback (most recent call last):
  File "/home/hilsenbek/workspace/thesis/syntax_transformer/training/train_cross_attention.py", line 110, in <module>
    trainer.train()
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/transformers/trainer.py", line 2171, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/transformers/trainer.py", line 2531, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/transformers/trainer.py", line 3675, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/transformers/trainer.py", line 3731, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/accelerate/utils/operations.py", line 823, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/accelerate/utils/operations.py", line 811, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py", line 603, in forward
    encoder_outputs = self.encoder(
                      ^^^^^^^^^^^^^
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/hilsenbek/.conda/envs/harness/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: RobertaModel.forward() got an unexpected keyword argument 'num_items_in_batch'

Who can help?

@ArthurZucker
@gheinrich

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Follow the blog https://huggingface.co/blog/encoder-decoder and train the resulting EncoderDecoderModel with the Trainer.
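A minimal sketch of the kind of setup that triggers this (the roberta-base encoder/decoder choice follows the blog; the dummy dataset and hyperparameters are illustrative, not the actual training script):

from datasets import Dataset
from transformers import (
    EncoderDecoderModel,
    RobertaTokenizerFast,
    Trainer,
    TrainingArguments,
)

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
model = EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

# Tiny dummy seq2seq dataset, just enough to start a training step
def preprocess(batch):
    enc = tokenizer(batch["src"], truncation=True, padding="max_length", max_length=32)
    dec = tokenizer(batch["tgt"], truncation=True, padding="max_length", max_length=32)
    enc["labels"] = dec["input_ids"]
    return enc

raw = Dataset.from_dict({"src": ["hello world"] * 8, "tgt": ["hallo welt"] * 8})
train_dataset = raw.map(preprocess, batched=True, remove_columns=["src", "tgt"])

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", per_device_train_batch_size=2, report_to="none"),
    train_dataset=train_dataset,
)
# On 4.48.x this raises:
# TypeError: RobertaModel.forward() got an unexpected keyword argument 'num_items_in_batch'
trainer.train()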

Expected behavior

Training should work as it did in older transformers versions.

@Bachstelze Bachstelze added the bug label Jan 22, 2025
@Rocketknight1 (Member)

This seems related to the trainer changes - cc @muellerzr @SunMarc

@shubhamjain0594

Getting the same error for the Gemma model.

@SilverSoldier (Contributor)

Same for Bloom, which marks unexpected arguments as deprecated and throws ValueError: Got unexpected arguments: {'num_items_in_batch': 5120}.

These three lines in Trainer.compute_loss seem to be causing the problem:

loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
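
A condensed illustration of why the merge above breaks models whose forward has a fixed signature (paraphrased Trainer logic with made-up inputs, not the actual implementation):

# Condensed illustration of the failure mode (paraphrased, made-up inputs)
def legacy_forward(input_ids, attention_mask=None, labels=None):
    # e.g. RobertaModel.forward / BloomForCausalLM.forward: no trailing **kwargs
    return None

inputs = {"input_ids": [[0, 1, 2]], "attention_mask": [[1, 1, 1]], "labels": [[0, 1, 2]]}
loss_kwargs = {"num_items_in_batch": 5120}

# Mirrors compute_loss: the loss kwargs are merged into the model inputs unconditionally
inputs = {**inputs, **loss_kwargs}
legacy_forward(**inputs)  # TypeError: got an unexpected keyword argument 'num_items_in_batch'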

@ArthurZucker (Collaborator)

We'll do a patch as soon as there is a fix!

@SunMarc (Member) commented Jan 23, 2025

Can you share the traceback for the Gemma model error, @shubhamjain0594?

For the Bloom error, this can easily be fixed by setting accepts_loss_kwargs = False in the Bloom modeling code. It happens because Bloom allows passing kwargs, hence the issue.

For the encoder-decoder model, this is because we allow passing **kwargs in the forward and kwargs_encoder is not set correctly.

I'll let @muellerzr decide how to fix these. Maybe the easiest fix would be to just set accepts_loss_kwargs = True for models that support it.
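
A sketch of a user-side workaround until a patch lands (it assumes the 4.48.x compute_loss signature shown in the traceback above, and it forgoes the num_items_in_batch loss normalization under gradient accumulation, i.e. it restores the old behaviour):

from transformers import Trainer

class DropLossKwargsTrainer(Trainer):
    # Temporary workaround sketch: never forward num_items_in_batch, so it is not
    # injected into forwards that do not accept extra kwargs.
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        return super().compute_loss(model, inputs, return_outputs=return_outputs)

Use DropLossKwargsTrainer in place of Trainer in the scripts above.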

@shubhamjain0594

[screenshot of the Gemma traceback]

@SunMarc here you go. Does this help?

@SunMarc (Member) commented Jan 23, 2025

Yeah, thanks! The issue comes from the attention refactor PR, where we pass **kwargs both in the loss calculation and in the model. cc @muellerzr

@muellerzr muellerzr self-assigned this Jan 23, 2025
@muellerzr (Contributor)

@shubhamjain0594 can you post a repr of the model you're using and how the Trainer is configured? I can't recreate this with "google/gemma-2-2b-it"

# End-to-end script running the Hugging Face Trainer
# for causal language modeling. Based on the Tasks documentation
# originally from: https://hf.co/docs/transformers/tasks/language_modeling
from accelerate import PartialState
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Constants
model_name = "google/gemma-2-2b-it"
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"

# Load dataset
print(f"Downloading dataset ({dataset_name})")
dataset = load_dataset(dataset_name, dataset_config, split="train[:500]")
dataset = dataset.train_test_split(test_size=0.2)

# Tokenize the dataset
tokenizer = AutoTokenizer.from_pretrained(model_name)


def tokenize_function(examples):
    return tokenizer(examples["text"])


print(f"Tokenizing dataset for {model_name}...")
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

# We still need to concatenate our sequences
# and split them into shorter chunks to keep
# RAM usage minimal
block_size = 128


def group_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder; we could add padding instead if the model supported it.
    # You can customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    # Split by chunks of block_size.
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


# And apply
tokenized_dataset = tokenized_dataset.map(group_texts, batched=True)

# Create an efficient collator which dynamically pads.
# The end-of-sequence token is used as the padding token, and mlm=False
# means the inputs are used as labels, shifted to the right by one element.
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

print(f"Instantiating model ({model_name})...")
model = AutoModelForCausalLM.from_pretrained(model_name)

# Define the hyperparameters in the TrainingArguments
print("Creating training arguments (weights are stored at `results/causal_language_modeling`)...")
training_args = TrainingArguments(
    output_dir="results/causal_language_modeling",  # Where weights are stored
    learning_rate=2e-5,  # The learning rate during training
    per_device_train_batch_size=1,  # Number of samples per batch during training
    per_device_eval_batch_size=1,  # Number of samples per batch during evaluation
    gradient_accumulation_steps=2,
    num_train_epochs=2,  # How many iterations through the dataloaders should be done
    weight_decay=0.01,  # Regularization penalization
)

# Create the `Trainer`, passing in the model and arguments
# the datasets to train on, how the data should be collated,
# and the method for computing our metrics
print("Creating `Trainer`...")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=data_collator,
)

# Initiate training
print("Training...")
trainer.train()

@muellerzr (Contributor)

Or @SilverSoldier or @Bachstelze

@shubhamjain0594 commented Jan 23, 2025

@shubhamjain0594 can you post a repr of the model you're using and how the Trainer is configured? I can't recreate this with "google/gemma-2-2b-it"

I am using google/gemma-1.1-2b-it, maybe that is the difference?

@muellerzr (Contributor)

@shubhamjain0594 I can recreate it with that model, thanks! :)

@muellerzr (Contributor)

Essentially this stems from certain model forwards not accepting kwargs, which is an issue on our end.
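
For context, a simplified contrast of the two signature styles (illustrative only, not copied from any modeling file; the TypedDict below is a local stand-in):

from typing import Optional
from typing_extensions import TypedDict, Unpack
import torch

class LossKwargs(TypedDict, total=False):
    # Local stand-in for the loss-kwargs TypedDict used by the refactored models
    num_items_in_batch: int

def refactored_forward(
    input_ids: torch.LongTensor,
    labels: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[LossKwargs],  # absorbs num_items_in_batch and can route it to the loss
):
    ...

def legacy_forward(
    input_ids: torch.LongTensor,
    labels: Optional[torch.LongTensor] = None,
):
    ...  # fixed signature: any extra keyword from the Trainer raises TypeError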

Said problem models:

FAILED tests/models/bamba/test_modeling_bamba.py::BambaModelTest::test_training_gradient_accumulation - TypeError: BambaModel.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/bert_generation/test_modeling_bert_generation.py::BertGenerationEncoderTest::test_training_gradient_accumulation - TypeError: BertGenerationDecoder.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/big_bird/test_modeling_big_bird.py::BigBirdModelTest::test_training_gradient_accumulation - TypeError: BigBirdForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/biogpt/test_modeling_biogpt.py::BioGptModelTest::test_training_gradient_accumulation - TypeError: BioGptForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/bloom/test_modeling_bloom.py::BloomModelTest::test_training_gradient_accumulation - ValueError: Got unexpected arguments: {'num_items_in_batch': 14}
FAILED tests/models/codegen/test_modeling_codegen.py::CodeGenModelTest::test_training_gradient_accumulation - TypeError: CodeGenForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/ctrl/test_modeling_ctrl.py::CTRLModelTest::test_training_gradient_accumulation - TypeError: CTRLLMHeadModel.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/data2vec/test_modeling_data2vec_text.py::Data2VecTextModelTest::test_training_gradient_accumulation - TypeError: Data2VecTextForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/dbrx/test_modeling_dbrx.py::DbrxModelTest::test_training_gradient_accumulation - TypeError: DbrxForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/electra/test_modeling_electra.py::ElectraModelTest::test_training_gradient_accumulation - TypeError: ElectraForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/ernie/test_modeling_ernie.py::ErnieModelTest::test_training_gradient_accumulation - TypeError: ErnieForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/falcon/test_modeling_falcon.py::FalconModelTest::test_training_gradient_accumulation - TypeError: FalconForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/fuyu/test_modeling_fuyu.py::FuyuModelTest::test_training_gradient_accumulation - TypeError: FuyuForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/gemma/test_modeling_gemma.py::GemmaModelTest::test_training_gradient_accumulation - TypeError: GemmaModel.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/gemma2/test_modeling_gemma2.py::GemmaModelTest::test_training_gradient_accumulation - TypeError: GemmaModel.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/git/test_modeling_git.py::GitModelTest::test_training_gradient_accumulation - TypeError: GitForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/gpt2/test_modeling_gpt2.py::GPT2ModelTest::test_training_gradient_accumulation - TypeError: GPT2LMHeadModel.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py::GPTBigCodeModelTest::test_training_gradient_accumulation - TypeError: GPTBigCodeForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/gpt_bigcode/test_modeling_gpt_bigcode.py::GPTBigCodeMHAModelTest::test_training_gradient_accumulation - TypeError: GPTBigCodeForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/gpt_neo/test_modeling_gpt_neo.py::GPTNeoModelTest::test_training_gradient_accumulation - TypeError: GPTNeoForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/gpt_neox/test_modeling_gpt_neox.py::GPTNeoXModelTest::test_training_gradient_accumulation - TypeError: GPTNeoXForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/gpt_neox_japanese/test_modeling_gpt_neox_japanese.py::GPTNeoXModelJapaneseTest::test_training_gradient_accumulation - TypeError: GPTNeoXJapaneseForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/gptj/test_modeling_gptj.py::GPTJModelTest::test_training_gradient_accumulation - TypeError: GPTJForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/granitemoe/test_modeling_granitemoe.py::GraniteMoeModelTest::test_training_gradient_accumulation - TypeError: GraniteMoeForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/helium/test_modeling_helium.py::GemmaModelTest::test_training_gradient_accumulation - TypeError: GemmaModel.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/jetmoe/test_modeling_jetmoe.py::JetMoeModelTest::test_training_gradient_accumulation - TypeError: JetMoeForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/megatron_bert/test_modeling_megatron_bert.py::MegatronBertModelTest::test_training_gradient_accumulation - TypeError: MegatronBertForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/mllama/test_modeling_mllama.py::MllamaForCausalLMModelTest::test_training_gradient_accumulation - ValueError: Unrecognized configuration class <class 'transformers.models.mllama.configuration_mllama.MllamaTextConfig'> for this kind of AutoModel: AutoModelForCausalLM.
FAILED tests/models/moshi/test_modeling_moshi.py::MoshiDecoderTest::test_training_gradient_accumulation - TypeError: MoshiForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/mpt/test_modeling_mpt.py::MptModelTest::test_training_gradient_accumulation - TypeError: MptForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/musicgen/test_modeling_musicgen.py::MusicgenDecoderTest::test_training_gradient_accumulation - ValueError: Unrecognized configuration class <class 'transformers.models.musicgen.configuration_musicgen.MusicgenDecoderConfig'> for this kind of AutoModel: AutoModelForCausalLM.
FAILED tests/models/musicgen_melody/test_modeling_musicgen_melody.py::MusicgenMelodyDecoderTest::test_training_gradient_accumulation - ValueError: Unrecognized configuration class <class 'transformers.models.musicgen_melody.configuration_musicgen_melody.MusicgenMelodyDecoderConfig'> for this kind of AutoModel: ...
FAILED tests/models/nemotron/test_modeling_nemotron.py::GemmaModelTest::test_training_gradient_accumulation - TypeError: GemmaModel.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/openai/test_modeling_openai.py::OpenAIGPTModelTest::test_training_gradient_accumulation - TypeError: OpenAIGPTLMHeadModel.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/opt/test_modeling_opt.py::OPTModelTest::test_training_gradient_accumulation - TypeError: OPTForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/persimmon/test_modeling_persimmon.py::PersimmonModelTest::test_training_gradient_accumulation - TypeError: PersimmonForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py::RecurrentGemmaModelTest::test_training_gradient_accumulation - TypeError: RecurrentGemmaForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/reformer/test_modeling_reformer.py::ReformerLocalAttnModelTest::test_training_gradient_accumulation - TypeError: ReformerModelWithLMHead.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/rembert/test_modeling_rembert.py::RemBertModelTest::test_training_gradient_accumulation - TypeError: RemBertForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/roberta/test_modeling_roberta.py::RobertaModelTest::test_training_gradient_accumulation - TypeError: RobertaForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/roberta_prelayernorm/test_modeling_roberta_prelayernorm.py::RobertaPreLayerNormModelTest::test_training_gradient_accumulation - TypeError: RobertaPreLayerNormForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/roc_bert/test_modeling_roc_bert.py::RoCBertModelTest::test_training_gradient_accumulation - TypeError: RoCBertForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/roformer/test_modeling_roformer.py::RoFormerModelTest::test_training_gradient_accumulation - TypeError: RoFormerForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/rwkv/test_modeling_rwkv.py::RwkvModelTest::test_training_gradient_accumulation - TypeError: RwkvForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/stablelm/test_modeling_stablelm.py::StableLmModelTest::test_training_gradient_accumulation - TypeError: StableLmForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/xglm/test_modeling_xglm.py::XGLMModelTest::test_training_gradient_accumulation - TypeError: XGLMForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/xlm/test_modeling_xlm.py::XLMModelTest::test_training_gradient_accumulation - TypeError: XLMWithLMHeadModel.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/xlm_roberta_xl/test_modeling_xlm_roberta_xl.py::XLMRobertaXLModelTest::test_training_gradient_accumulation - TypeError: XLMRobertaXLForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'
FAILED tests/models/xmod/test_modeling_xmod.py::XmodModelTest::test_training_gradient_accumulation - TypeError: XmodForCausalLM.forward() got an unexpected keyword argument 'num_items_in_batch'

@muellerzr muellerzr linked a pull request Jan 24, 2025 that will close this issue
@techkang (Contributor)

I think the better way to fix this bug is to delete:

**kwargs: Unpack[KwargsForCausalLM],

and

Forcefully enabling variable args but not using num_items_in_batch for the loss calculation will make the training loss gradient_accumulation_steps times larger than before.
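
To make the scaling concern concrete, a small numeric sketch (made-up token counts; this mirrors the reasoning, not the Trainer implementation):

import torch

# Two micro-batches in one gradient-accumulation cycle, per-token loss 2.0 everywhere
micro_batches = [torch.full((10,), 2.0), torch.full((30,), 2.0)]
num_items_in_batch = sum(t.numel() for t in micro_batches)  # 40 labelled tokens in total

# Loss that uses num_items_in_batch: summing over the accumulation steps gives the true mean
loss_with_kwarg = sum(t.sum() / num_items_in_batch for t in micro_batches)  # 80 / 40 = 2.0

# If forward() accepts **kwargs (so the Trainer skips its own rescaling) but the loss ignores
# num_items_in_batch and averages per micro-batch, the accumulated loss is
# gradient_accumulation_steps times larger
loss_without_kwarg = sum(t.sum() / t.numel() for t in micro_batches)  # 2.0 + 2.0 = 4.0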
