System Info
transformers version: 4.48.1

Who can help?
@amyeroberts, @qubvel
Information

Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
Shell command used to launch FSDP:
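The command itself isn't shown above; given the files named below, it was presumably along the lines of

accelerate launch --config_file a100_config.yaml train.py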
train.py uses a standard SFTTrainer() trainer object called via trainer.train() after loading the processor and model, nothing fancy there; a minimal sketch of such a setup follows.
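A hypothetical version of that setup, patterned on TRL's VLM SFT example (model id, dataset, collator, and hyperparameters are illustrative, not the author's actual train.py):

import torch
from datasets import load_dataset
from transformers import AutoProcessor, MllamaForConditionalGeneration
from trl import SFTConfig, SFTTrainer

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = MllamaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)

def collate_fn(examples):
    # Render chat-formatted text and batch it together with the images.
    texts = [processor.apply_chat_template(ex["messages"], tokenize=False) for ex in examples]
    images = [ex["images"] for ex in examples]
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100  # ignore padding in the loss
    batch["labels"] = labels
    return batch

trainer = SFTTrainer(
    model=model,
    args=SFTConfig(
        output_dir="mllama-sft",
        per_device_train_batch_size=1,
        remove_unused_columns=False,
        dataset_kwargs={"skip_prepare_dataset": True},  # the custom collator does the preparation
    ),
    train_dataset=load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train"),
    data_collator=collate_fn,
    processing_class=processor.tokenizer,
)
trainer.train()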
The FSDP config a100_config.yaml is as follows:
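The config block itself isn't reproduced above; a hypothetical stand-in built from standard accelerate FSDP keys (the author's actual values may differ) would be:

compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: bf16
num_machines: 1
num_processes: 8
fsdp_config:
  fsdp_offload_params: true          # the parameter offloading discussed below
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_state_dict_type: SHARDED_STATE_DICT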
Note that we can specify mixed_precision: 'no' for the config and get the same error as well.

Expected behavior
FSDP (with offloaded parameters) training halts during the first forward pass. The resulting error,

[rank4]: File "/local_disk0/.ephemeral_nfs/cluster_libraries/python/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 549, in _conv_forward
[rank4]: return F.conv2d(
[rank4]: ^^^^^^^^^
[rank4]: RuntimeError: Input type (torch.FloatTensor) and weight type (CUDABFloat16Type) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

is caused by modeling_mllama.py line 1541, which casts pixel_values to the vision model's dtype and device via .to() before applying the patch embedding.

AFAIK we cannot generally use .to() assignments referencing the model's device or datatype at the start of the forward pass when training via FSDP with offloaded params, as the parameters may not be located on the device, or carry the datatype, that is actually used during the forward pass.

The simplest solution here is to simply omit both assignments: if we change line 1541 to

patch_embeds = self.patch_embedding(pixel_values)

the problem is eliminated. One would need to change the documentation for Llama 3.2 Vision, as inference currently involves sending processed inputs to the model's device but not its dtype. For example, the official inference code snippet here needs to become one that passes the dtype as well; a sketch of that change follows.
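A hedged sketch of that change, following the model card's standard snippet (the exact snippet isn't reproduced above):

import requests
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(model_id)

url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/rabbit.jpg"  # placeholder: any reachable image works
image = Image.open(requests.get(url, stream=True).raw)
prompt = "<|image|><|begin_of_text|>If I had to write a haiku for this one"

# Today the docs move inputs to the model's device only:
#     inputs = processor(image, prompt, return_tensors="pt").to(model.device)
# With the internal casts removed, the dtype has to come along too. Note that
# BatchFeature.to() applies a dtype only to floating-point tensors, so
# input_ids remain integer:
inputs = processor(image, prompt, return_tensors="pt").to(model.device, dtype=torch.bfloat16)

output = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(output[0]))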
If this is an acceptable approach I would be happy to put in the PR, just let me know!
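For reference, a minimal sketch of the failure mode and of the calling convention the fix relies on. The toy module below is hypothetical (not the Mllama code), it assumes the two assignments take the form pixel_values.to(self.dtype).to(self.device), and it needs an initialized process group (e.g. under torchrun) plus at least one GPU:

import torch
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class TinyVision(nn.Module):
    """Toy stand-in for the vision tower's patch embedding."""

    def __init__(self):
        super().__init__()
        self.patch_embedding = nn.Conv2d(3, 8, kernel_size=14, stride=14)

    def forward(self, pixel_values):
        # Problematic pattern (what modeling_mllama.py line 1541 effectively does,
        # self.dtype/self.device being PreTrainedModel properties derived from
        # the parameters):
        #     pixel_values = pixel_values.to(self.dtype).to(self.device)
        # With offload_params=True the parameters live on CPU between compute
        # phases, so the dtype/device read from them can disagree with the
        # gathered CUDA weights the conv kernel actually uses, yielding
        # "Input type (torch.FloatTensor) and weight type (CUDABFloat16Type)...".
        return self.patch_embedding(pixel_values)  # proposed fix: no casts here

# Assumes torch.distributed.init_process_group has already run (one rank per GPU):
local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
fsdp_model = FSDP(
    TinyVision().to(torch.bfloat16),
    device_id=local_rank,
    cpu_offload=CPUOffload(offload_params=True),
)

# The caller now owns device and dtype placement, mirroring the documentation change:
pixel_values = torch.randn(1, 3, 448, 448).to(f"cuda:{local_rank}", dtype=torch.bfloat16)
out = fsdp_model(pixel_values)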
Hi @blbadger, thanks for opening an issue and proposing a solution! As far as I know, other models in Transformers have a similar pattern. Although the fix might be viable for FSDP, it could be breaking for existing scenarios. I might not have enough experience with that, so let's get some opinions from @SunMarc or @muellerzr.