Mllama training via FSDP device and dtype misassignment #35880

Open
blbadger opened this issue Jan 24, 2025 · 1 comment

@blbadger (Contributor)

System Info

  • transformers version: 4.48.1
  • Platform: Linux-5.15.0-1073-azure-x86_64-with-glibc2.35
  • Python version: 3.11.0rc1
  • Huggingface_hub version: 0.27.1
  • Safetensors version: 0.5.2
  • Accelerate version: 1.3.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.5.1+cu124 (True)
  • Tensorflow version (GPU?): 2.15.0 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: Yes
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A100-SXM4-40GB

Who can help?

@amyeroberts, @qubvel

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Shell command used to launch FSDP:

accelerate launch --config_file "configs/a100_config.yaml" train.py \
--seed 100 \
--model_name_or_path "/path/to/model" \
--dataset_path "/path/to/dataset" \
--add_special_tokens False \
--append_concat_token False \
--max_seq_len 1024 \
--num_train_epochs 15 \
--logging_steps 50 \
--log_level "info" \
--logging_strategy "steps" \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--bf16 True \
--fp16 False \
--packing False \
--learning_rate 7e-5 \
--lr_scheduler_type "linear" \
--weight_decay 0.0 \
--warmup_ratio 0.0 \
--max_grad_norm 1.0 \
--output_dir "/path/to/output" \
--per_device_train_batch_size 1 \
--gradient_checkpointing True \
--use_reentrant True \
--dataset_text_field "content" \
--use_flash_attn False \
--use_peft_lora False \
--report_to "none"

train.py uses a standard SFTTrainer() object, called via trainer.train() after loading the processor and model; nothing fancy there.

def main(model_args, data_args, training_args):
    ...  # dataset, model, processor initialization

    sft_training_args = SFTConfig(
        output_dir=training_args.output_dir,
        gradient_checkpointing=training_args.gradient_checkpointing,
        bf16=training_args.bf16,
        remove_unused_columns=False,
        report_to=training_args.report_to,
        num_train_epochs=training_args.num_train_epochs,
        logging_steps=training_args.logging_steps,
        evaluation_strategy=training_args.evaluation_strategy,
        save_strategy=training_args.save_strategy,
        max_seq_length=data_args.max_seq_length,
    )

    trainer = SFTTrainer(
        model=model,
        peft_config=peft_config,
        args=sft_training_args,
        data_collator=collate_fn,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        processing_class=processor.tokenizer,
    )
    trainer.train()
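
For context, a minimal sketch of what the elided model/processor initialization might look like (the class names and flags below are assumptions inferred from the launch command, not the actual script):

import torch
from transformers import AutoProcessor, MllamaForConditionalGeneration

# Assumed setup; torch_dtype mirrors --bf16 True from the launch command
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path)
model = MllamaForConditionalGeneration.from_pretrained(
    model_args.model_name_or_path,
    torch_dtype=torch.bfloat16,
)
peft_config = None  # LoRA disabled via --use_peft_lora False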

The FSDP config a100_config.yaml is as follows:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'bf16'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Note that specifying mixed_precision: 'no' in the config produces the same error.

Expected behavior

FSDP training (with offloaded parameters) is expected to complete the forward pass; instead, it halts during the first forward pass. The resulting error,

 [rank4]:   File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 549, in _conv_forward
[rank4]:     return F.conv2d(
[rank4]:            ^^^^^^^^^
[rank4]: RuntimeError: Input type (torch.FloatTensor) and weight type (CUDABFloat16Type) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

is caused by modeling_mllama.py line 1541,

patch_embeds = self.patch_embedding(pixel_values.to(self.dtype).to(self.device))

AFAIK we cannot, in general, use .to() calls that reference the model's device or dtype at the start of the forward pass when training via FSDP with offloaded params, because at that point the parameters may not reside on the device, or carry the dtype, that is actually used during the forward pass.
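
For illustration, here is a minimal standalone sketch (an assumed repro with plain tensors on a CUDA machine, not the actual FSDP run) of the dtype/device mismatch that F.conv2d raises above:

import torch
import torch.nn.functional as F

# Under FSDP mixed precision the conv weight is materialized as a bf16 CUDA tensor...
weight = torch.randn(8, 3, 14, 14, dtype=torch.bfloat16, device="cuda")
# ...while the misdirected .to(self.dtype).to(self.device) can leave pixel_values as fp32 on CPU
pixel_values = torch.randn(1, 3, 224, 224)

# Raises: RuntimeError: Input type (torch.FloatTensor) and weight type (...) should be the same ...
F.conv2d(pixel_values, weight)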

The simplest solution is to omit both assignments: if we change line 1541 to

patch_embeds = self.patch_embedding(pixel_values)

the problem is eliminated. The documentation for Llama 3.2 Vision would also need updating, since inference currently sends processed inputs to the model's device but not its dtype. For example, the official inference code snippet here

inputs = processor(image, prompt, return_tensors="pt").to(model.device)

needs to become

inputs = processor(image, prompt, return_tensors="pt").to(model.dtype).to(model.device)
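
For completeness, a hedged sketch of what the full inference flow might then look like (the checkpoint id is assumed and the image URL is a placeholder, not taken from the official docs):

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-11B-Vision"  # assumed checkpoint; any Mllama checkpoint should behave the same
model = MllamaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open(requests.get("https://example.com/image.jpg", stream=True).raw)  # placeholder URL
prompt = "<|image|><|begin_of_text|>Describe this image"

# With the proposed change the caller casts pixel values to the model dtype as well as moving them to its device
inputs = processor(image, prompt, return_tensors="pt").to(model.dtype).to(model.device)
output = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(output[0]))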

If this is an acceptable approach I would be happy to put in the PR, just let me know!

blbadger added the bug label on Jan 24, 2025
@qubvel (Member) commented on Jan 27, 2025

Hi @blbadger, thanks for opening an issue and proposing a solution! As far as I know, other models in Transformers have a similar pattern. Although the fix might be viable for FSDP, it could break existing scenarios. I might not have enough experience with that, so let's get some opinions from @SunMarc or @muellerzr.
