qwen2.5vl ft #36

Open
SFTJBD opened this issue Feb 9, 2025 · 17 comments
@SFTJBD

SFTJBD commented Feb 9, 2025

Thanks for the update supporting Qwen2.5-VL! I prepared the video data as you described, but I don't know why it's throwing this error.

In /training/data.py:

pixel_values = torch.cat(batch_pixel_values, dim=0)
[rank0]: TypeError: expected Tensor as element 0 in argument 0, but got list
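
For context, the error means element 0 of batch_pixel_values is a Python list (e.g. per-frame tensors from the video processor) rather than a Tensor, which torch.cat cannot handle. A minimal, hypothetical workaround sketch (not the repo's eventual fix) would flatten nested entries before concatenating:

from typing import List, Sequence, Union

import torch

def flatten_and_cat(batch_pixel_values: Sequence[Union[torch.Tensor, List[torch.Tensor]]]) -> torch.Tensor:
    # Hypothetical helper: accept either a Tensor or a list of Tensors per example
    # and return the single concatenated Tensor that torch.cat expects.
    flat: List[torch.Tensor] = []
    for pv in batch_pixel_values:
        if isinstance(pv, (list, tuple)):
            flat.extend(pv)
        else:
            flat.append(pv)
    return torch.cat(flat, dim=0)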

@2U1
Owner

2U1 commented Feb 9, 2025

Does the latest code still have the problem? I tested it last week and it was okay then.

@SFTJBD
Author

SFTJBD commented Feb 9, 2025

Does the latest code still have the problem? I tested it last week and it was okay then.

I'll try it now. Maybe my copy of the code isn't the latest.

@2U1
Owner

2U1 commented Feb 9, 2025

Let me know if it still has the issue.
Thank you.

@Godheritage

Let me know if it still has the issue. Thank you.

I met the same problem, and I cloned the code on 2/10/2025. The output is:

[rank2]: File "/2022233313/Qwen2.5_VL/Qwen2-VL-Finetune/src/training/data.py", line 311, in __call__
[rank2]: data_dict["second_per_grid_ts"] = torch.stack(batch_second_per_grid_ts)
[rank2]: TypeError: expected Tensor as element 0 in argument 0, but got list

My data JSON looks like:

[
    {
        "id": "xxx",
        "video": "xxx.mp4",
        "conversations": [
            {
                "from": "human",
                "value": "

I changed data.py line 288 to batch_second_per_grid_ts.extend([torch.tensor(ts) for ts in example["second_per_grid_ts"]]).
It seems to work, but I met a new problem:

File "/opt/conda/envs/qwen2/lib/python3.10/site-packages/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py", line 1627, in get_rope_index
[rank2]: time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second
[rank2]: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and CPU!

This problem may be caused by the qwen2_5_vl code in transformers. Could you create a monkey-patch function to fix it? Thanks a lot!
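
For context, the RuntimeError happens because the per-video timing value ends up as a CPU tensor while expanded_range lives on the GPU. A toy, hypothetical illustration of the failure and the kind of device cast that avoids it (assumed values; the fix that actually landed in the repo is collator-side, see the maintainer's reply below):

import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
expanded_range = torch.arange(8, device=device)
second_per_grid_t = torch.tensor(0.5)  # CPU tensor, as produced by the modified collator
tokens_per_second = 2                  # stands in for config.vision_config.tokens_per_second

# Moving the scalar tensor onto expanded_range's device avoids the device mismatch.
second_per_grid_t = second_per_grid_t.to(expanded_range.device)
time_tensor = expanded_range * second_per_grid_t * tokens_per_second
print(time_tensor)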

@2U1
Owner

2U1 commented Feb 10, 2025

@Godheritage Sorry, I've noticed I only tested video on Qwen2-VL, not 2.5. I'll take a look.
Sorry for the inconvenience.

@2U1
Owner

2U1 commented Feb 11, 2025

@Godheritage I've fixed the code. It works fine now.
It doesn't need monkey patching; my collator just needed a small fix to flatten the video batch_second_per_grid_ts.
Thanks for reporting!
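
For anyone hitting this before pulling the update, a minimal sketch of that kind of collator-side flattening (assumed structure of example["second_per_grid_ts"]; not the exact code that landed in the repo):

from typing import Any, Dict, List

def collect_second_per_grid_ts(examples: List[Dict[str, Any]]) -> List[float]:
    # Hypothetical helper: gather one float per video across the whole batch as a
    # flat Python list instead of stacking per-example tensors, so no CPU tensor
    # ever reaches get_rope_index.
    batch_second_per_grid_ts: List[float] = []
    for example in examples:
        batch_second_per_grid_ts.extend(float(ts) for ts in example.get("second_per_grid_ts", []))
    return batch_second_per_grid_ts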

@Godheritage

@2U1 Thanks for your quick response and generous work. I've now succeeded in fine-tuning Qwen2.5-VL with my videos.

@SFTJBD
Author

SFTJBD commented Feb 16, 2025

Let me know if it still has the issue. Thank you.

Thank you for the update. The video data training is running fine. However, a new issue has arisen: I'm trying to train using both video and image data from the same source, but the program keeps freezing with no error messages. (The code currently handles both video and image inputs, so the problem likely lies in the DataLoader.) What do you think are some good approaches to address this?

@2U1
Owner

2U1 commented Feb 16, 2025

@SFTJBD You need to use zero2 to finetune video+image together. It's a bit tricky to make it work with zero3, so I'm working on that.
I've written this in the README.
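
In case it helps, one way to point the HF Trainer at ZeRO-2 is a plain config dict (a sketch with assumed settings and assuming deepspeed is installed; if the repo provides its own zero2 config for its launch scripts, prefer that):

from transformers import TrainingArguments

# Minimal ZeRO-2 DeepSpeed config; "auto" lets the HF integration fill in values
# from TrainingArguments.
ds_zero2 = {
    "zero_optimization": {"stage": 2, "overlap_comm": True, "contiguous_gradients": True},
    "bf16": {"enabled": "auto"},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
}

args = TrainingArguments(output_dir="outputs", deepspeed=ds_zero2)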

@2U1
Owner

2U1 commented Feb 18, 2025

@SFTJBD I've updated the code to support zero3 with all kinds of mixed-modality data.

@SFTJBD
Author

SFTJBD commented Feb 19, 2025

@SFTJBD I've updated the code to support zero3 with all kinds of mixed-modality data.

Thank you for the update. I am trying to use both video and image data within the same 'value' field in my data. Does the current code support this functionality? (Similar to: "conversations": [ { "from": "human", "value": "<video>\n<image>\n My question." }, ])

@2U1
Owner

2U1 commented Feb 19, 2025

@SFTJBD I wasn't considering that case. The code needs to be fixed to support it.

@SFTJBD
Author

SFTJBD commented Feb 20, 2025

@SFTJBD I wasn't considering that case. The code needs to be fixed to support it.

OK, got it. So putting video and image data within multi-turn conversations will also hit the same problem, right? (I've made simple modifications to the processor's input so it can handle both modalities simultaneously, but it still doesn't run properly during training. Could this be a problem caused by DeepSpeed?)

@2U1
Owner

2U1 commented Feb 20, 2025

@SFTJBD Yes, it's likely a DeepSpeed issue. To resolve it, you'll need to directly modify the forward function (in the monkey patching file). When using DeepSpeed, all GPUs must have the same CUDA graph, meaning every GPU needs to pass through the visual processing the same number of times. However, if video and image data are mixed in one batch, some GPUs might go through the visual pathway twice, while in batches with only a single image (or in text-only cases due to a dummy), they only pass through once. Ensuring that the visual processing happens twice in every batch should allow it to work correctly.

@2U1
Owner

2U1 commented Feb 20, 2025

@SFTJBD I'll make a quick fix with it soon. Thanks for letting me know.

@2U1
Owner

2U1 commented Feb 20, 2025

@SFTJBD I can't test it right now but

# Imports needed to run this patched forward standalone (assuming liger-kernel is
# installed, since the fused cross-entropy path below relies on it).
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast


def qwen2_5_mixed_modality_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Dummy visual inputs: under DeepSpeed every rank must run the visual tower the
    # same number of times, even for batches without images or videos.
    # Cast to the visual tower's dtype to avoid a float32 / bf16 mismatch.
    dummy_pixel_img = torch.zeros(14308, 1176, device=self.visual.device, dtype=self.visual.dtype)
    dummy_grid_img = torch.tensor([[1, 98, 146]], device=self.visual.device)
    dummy_pixel_vid = torch.zeros(14308, 1176, device=self.visual.device, dtype=self.visual.dtype)
    dummy_grid_vid = torch.tensor([[1, 98, 146]], device=self.visual.device)

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)
        if pixel_values is not None:
            pixel_values = pixel_values.type(self.visual.dtype)
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
            if image_embeds.shape[0] != n_image_tokens:
                raise ValueError(
                    f"Image features({image_embeds.shape[0]}) and tokens({n_image_tokens}) mismatch."
                )
            mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(mask, image_embeds)
        else:
            # No image in this batch: run the visual tower on the dummy input so the
            # execution graph matches ranks that do have images.
            _ = self.visual(dummy_pixel_img, grid_thw=dummy_grid_img)

        if pixel_values_videos is not None:
            pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
            if video_embeds.shape[0] != n_video_tokens:
                raise ValueError(
                    f"Video features({video_embeds.shape[0]}) and tokens({n_video_tokens}) mismatch."
                )
            mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(mask, video_embeds)
        else:
            # Likewise for video: a dummy pass keeps the visual tower's call count
            # identical across ranks when this batch has no video.
            _ = self.visual(dummy_pixel_vid, grid_thw=dummy_grid_vid)

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        # calculate RoPE index once per generation in the pre-fill stage only
        if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
            position_ids, rope_deltas = self.get_rope_index(
                input_ids, image_grid_thw, video_grid_thw, second_per_grid_ts, attention_mask,
            )
            self.rope_deltas = rope_deltas
        else:
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = (
                (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                if cache_position is not None
                else 0
            )
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            if cache_position is not None:
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
            position_ids = position_ids.add(delta)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

    outputs = self.model(
        input_ids=None,
        position_ids=position_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and (labels is not None):
        # Training path: Liger's fused linear + cross-entropy avoids materializing
        # the full logits tensor.
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    else:
        logits = self.lm_head(hidden_states)
        if labels is not None:
            logits = logits.float()
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            loss_fct = CrossEntropyLoss()
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return Qwen2_5_VLCausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        rope_deltas=self.rope_deltas,
    )

Maybe this could resolve the problem for you.
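
If you want to try it before the repo updates, one way to wire it in (assuming a class-level patch; the repo's own monkey-patch file may apply it differently):

from transformers import Qwen2_5_VLForConditionalGeneration

# Replace the forward method on the class before the model/Trainer is built, so
# the training loop uses the patched mixed-modality path defined above.
Qwen2_5_VLForConditionalGeneration.forward = qwen2_5_mixed_modality_forward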
