From bf3b79efb82676219a3275764d8fcf4c70097ce5 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Wed, 5 Feb 2025 13:31:38 -0800 Subject: [PATCH] [VLM] Qwen2.5-VL --- docs/source/models/supported_models.md | 11 + examples/offline_inference/vision_language.py | 31 + .../vision_language_multi_image.py | 58 + .../vision_language/test_models.py | 22 + .../multimodal/processing/test_common.py | 1 + tests/models/registry.py | 2 + vllm/entrypoints/chat_utils.py | 4 +- .../model_executor/layers/rotary_embedding.py | 58 +- vllm/model_executor/models/qwen2_5_vl.py | 1133 +++++++++++++++++ vllm/model_executor/models/qwen2_vl.py | 16 +- vllm/model_executor/models/registry.py | 1 + vllm/v1/worker/gpu_model_runner.py | 12 +- vllm/worker/cpu_model_runner.py | 9 +- vllm/worker/model_runner.py | 9 +- 14 files changed, 1315 insertions(+), 52 deletions(-) create mode 100644 vllm/model_executor/models/qwen2_5_vl.py diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index d8e2842929515..3e8b2f89642c4 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -846,6 +846,13 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ * ✅︎ * ✅︎ +- * `Qwen2_5_VLForConditionalGeneration` + * Qwen2.5-VL + * T + IE+ + VE+ + * `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc. + * + * ✅︎ + * ✅︎ - * `UltravoxModel` * Ultravox * T + AE+ @@ -880,6 +887,10 @@ The chat template for Pixtral-HF is incorrect (see [discussion](https://huggingf A corrected version is available at . ::: +:::{note} +To use Qwen2.5-VL series models, you have to install Huggingface `transformers` library from source via `pip install git+https://github.com/huggingface/transformers`. +::: + ### Pooling Models See [this page](pooling-models) for more information on how to use pooling models. diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index 65940b6ada883..436c36570599a 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -531,6 +531,36 @@ def run_qwen2_vl(question: str, modality: str): return llm, prompt, stop_token_ids +# Qwen2.5-VL +def run_qwen2_5_vl(question: str, modality: str): + + model_name = "Qwen/Qwen2.5-VL-3B-Instruct" + + llm = LLM( + model=model_name, + max_model_len=4096, + max_num_seqs=5, + mm_processor_kwargs={ + "min_pixels": 28 * 28, + "max_pixels": 1280 * 28 * 28, + "fps": 1, + }, + disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache, + ) + + if modality == "image": + placeholder = "<|image_pad|>" + elif modality == "video": + placeholder = "<|video_pad|>" + + prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>" + f"{question}<|im_end|>\n" + "<|im_start|>assistant\n") + stop_token_ids = None + return llm, prompt, stop_token_ids + + model_example_map = { "aria": run_aria, "blip-2": run_blip2, @@ -557,6 +587,7 @@ def run_qwen2_vl(question: str, modality: str): "pixtral_hf": run_pixtral_hf, "qwen_vl": run_qwen_vl, "qwen2_vl": run_qwen2_vl, + "qwen2_5_vl": run_qwen2_5_vl, } diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 601ac96e16eac..8d2172a606f8d 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -392,6 +392,63 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData: ) +def load_qwen2_5_vl(question, image_urls: List[str]) -> ModelRequestData: + try: + from qwen_vl_utils import process_vision_info + except ModuleNotFoundError: + print('WARNING: `qwen-vl-utils` not installed, input images will not ' + 'be automatically resized. You can enable this functionality by ' + '`pip install qwen-vl-utils`.') + process_vision_info = None + + model_name = "Qwen/Qwen2.5-VL-3B-Instruct" + + llm = LLM( + model=model_name, + max_model_len=32768 if process_vision_info is None else 4096, + max_num_seqs=5, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = [{"type": "image", "image": url} for url in image_urls] + messages = [{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": + "user", + "content": [ + *placeholders, + { + "type": "text", + "text": question + }, + ], + }] + + processor = AutoProcessor.from_pretrained(model_name) + + prompt = processor.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + stop_token_ids = None + + if process_vision_info is None: + image_data = [fetch_image(url) for url in image_urls] + else: + image_data, _ = process_vision_info(messages, + return_video_sample_fps=False) + + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=image_data, + chat_template=None, + ) + + model_example_map = { "aria": load_aria, "deepseek_vl_v2": load_deepseek_vl2, @@ -404,6 +461,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData: "pixtral_hf": load_pixtral_hf, "qwen_vl_chat": load_qwen_vl_chat, "qwen2_vl": load_qwen2_vl, + "qwen2_5_vl": load_qwen2_5_vl, } diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 85bc4ac13182d..95505dcf5c29f 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -121,6 +121,8 @@ else ("half", "float")), marks=[pytest.mark.core_model], ), + # TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL + # once we upgraded to transformers>=4.49.0. "qwen2_vl": VLMTestInfo( models=["Qwen/Qwen2-VL-2B-Instruct"], test_type=( @@ -138,6 +140,26 @@ image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), + "qwen2_5_vl": VLMTestInfo( + models=["Qwen/Qwen2.5-VL-3B-Instruct"], + test_type=( + VLMTestType.IMAGE, + VLMTestType.MULTI_IMAGE, + VLMTestType.VIDEO + ), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501 + video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModelForVision2Seq, + vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, + image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + marks=[pytest.mark.skipif( + TRANSFORMERS_VERSION < "4.49.0", + reason="HF model requires transformers>=4.49.0", + ), pytest.mark.core_model, pytest.mark.cpu_model], + ), #### Extended model tests "aria": VLMTestInfo( models=["rhymes-ai/Aria"], diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py index 5cd749cbd7795..77cf3442df905 100644 --- a/tests/models/multimodal/processing/test_common.py +++ b/tests/models/multimodal/processing/test_common.py @@ -161,6 +161,7 @@ def _test_processing_correctness( "nvidia/NVLM-D-72B", "Qwen/Qwen-VL-Chat", "Qwen/Qwen2-VL-2B-Instruct", + "Qwen/Qwen2.5-VL-3B-Instruct", "Qwen/Qwen2-Audio-7B-Instruct", "fixie-ai/ultravox-v0_3", ]) diff --git a/tests/models/registry.py b/tests/models/registry.py index 285fbe4848090..20787fe008aa8 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -264,6 +264,8 @@ def check_available_online( trust_remote_code=True), "Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501 "Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501 + "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501 + min_transformers_version="4.49"), # noqa: E501 "UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3", trust_remote_code=True), # [Encoder-decoder] diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 3a6e75b1d8e58..f04902ae1c767 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -410,7 +410,7 @@ def _placeholder_str(self, modality: ModalityStr, return "" if model_type == "mllama": return "<|image|>" - if model_type == "qwen2_vl": + if model_type in ("qwen2_vl", "qwen2_5_vl"): return "<|vision_start|><|image_pad|><|vision_end|>" if model_type == "molmo": return "" @@ -430,7 +430,7 @@ def _placeholder_str(self, modality: ModalityStr, return "()" raise TypeError(f"Unknown model type: {model_type}") elif modality == "video": - if model_type == "qwen2_vl": + if model_type in ("qwen2_vl", "qwen2_5_vl"): return "<|vision_start|><|video_pad|><|vision_end|>" if model_type in ("minicpmo", "minicpmv"): return "()" diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 814c3b7d9cd83..b3b9b0e876057 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -27,6 +27,7 @@ import torch import torch.nn as nn +from transformers import PretrainedConfig from vllm.model_executor.custom_op import CustomOp @@ -772,8 +773,12 @@ def __init__( dtype: torch.dtype, mrope_section: Optional[List[int]] = None, ) -> None: - super().__init__(head_size, rotary_dim, max_position_embeddings, base, - is_neox_style, dtype) + # In Qwen2.5-VL, the maximum index value is related to the duration of + # the input video. We enlarge max_position_embeddings to 4 times to get + # a larger the cos and sin cache. + self.cache_max_position_num = max_position_embeddings * 4 + super().__init__(head_size, rotary_dim, self.cache_max_position_num, + base, is_neox_style, dtype) self.mrope_section = mrope_section if self.mrope_section: @@ -831,13 +836,10 @@ def forward( @staticmethod def get_input_positions( input_tokens: List[int], + hf_config: PretrainedConfig, image_grid_thw: Union[List[List[int]], torch.Tensor], video_grid_thw: Union[List[List[int]], torch.Tensor], - image_token_id: int, - video_token_id: int, - vision_start_token_id: int, - vision_end_token_id: int, - spatial_merge_size: int, + second_per_grid_ts: Optional[List[float]] = None, context_len: int = 0, seq_len: Optional[int] = None, ) -> Tuple[List[List[int]], int]: @@ -845,16 +847,13 @@ def get_input_positions( llm_positions, mrope_position_delta = \ MRotaryEmbedding.get_input_positions_tensor( - input_tokens, - image_grid_thw, - video_grid_thw, - image_token_id, - video_token_id, - vision_start_token_id, - vision_end_token_id, - spatial_merge_size, - context_len, - seq_len, + input_tokens=input_tokens, + hf_config=hf_config, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + context_len=context_len, + seq_len=seq_len, ) return llm_positions.tolist(), mrope_position_delta @@ -862,18 +861,22 @@ def get_input_positions( @staticmethod def get_input_positions_tensor( input_tokens: List[int], + hf_config: PretrainedConfig, image_grid_thw: Union[List[List[int]], torch.Tensor], video_grid_thw: Union[List[List[int]], torch.Tensor], - image_token_id: int, - video_token_id: int, - vision_start_token_id: int, - vision_end_token_id: int, - spatial_merge_size: int, + second_per_grid_ts: Optional[List[float]] = None, context_len: int = 0, seq_len: Optional[int] = None, ) -> Tuple[torch.Tensor, int]: """Get mrope input positions and delta value.""" + image_token_id = hf_config.image_token_id + video_token_id = hf_config.video_token_id + vision_start_token_id = hf_config.vision_start_token_id + spatial_merge_size = hf_config.vision_config.spatial_merge_size + tokens_per_second = getattr(hf_config.vision_config, + "tokens_per_second", 1.0) + if isinstance(image_grid_thw, torch.Tensor): image_grid_thw = image_grid_thw.tolist() if isinstance(video_grid_thw, torch.Tensor): @@ -892,6 +895,7 @@ def get_input_positions_tensor( image_index, video_index = 0, 0 for _ in range(image_nums + video_nums): + video_second_per_grid_t = 0.0 if image_token_id in input_tokens and remain_images > 0: ed_image = input_tokens.index(image_token_id, st) else: @@ -915,9 +919,13 @@ def get_input_positions_tensor( video_grid_thw[video_index][1], video_grid_thw[video_index][2], ) + video_second_per_grid_t = 1.0 + if second_per_grid_ts is not None: + video_second_per_grid_t = second_per_grid_ts[video_index] video_index += 1 remain_videos -= 1 ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = \ t, h // spatial_merge_size, w // spatial_merge_size text_len = ed - st @@ -927,8 +935,10 @@ def get_input_positions_tensor( llm_pos_ids_list.append( torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) - t_index = torch.arange(llm_grid_t).view(-1, 1).expand( - -1, llm_grid_h * llm_grid_w).flatten() + t_index = (torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t * + tokens_per_second).long().flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( llm_grid_t, -1, llm_grid_w).flatten() w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py new file mode 100644 index 0000000000000..e93cf46b900b6 --- /dev/null +++ b/vllm/model_executor/models/qwen2_5_vl.py @@ -0,0 +1,1133 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +# Copyright 2025 The vLLM team. +# Copyright 2025 The Qwen Team. +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights.""" +from functools import cached_property, partial +from typing import (Callable, Iterable, List, Literal, Mapping, Optional, Set, + Tuple, TypedDict, Union) + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import BatchFeature +from transformers.models.qwen2_5_vl import (Qwen2_5_VLImageProcessor, + Qwen2_5_VLProcessor) +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( + Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) + +from vllm.attention import AttentionMetadata +from vllm.config import VllmConfig +from vllm.distributed import parallel_state +from vllm.distributed import utils as dist_utils +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import MultiModalFieldConfig +from vllm.platforms import _Backend +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.config import uses_mrope + +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder +from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, + apply_rotary_pos_emb_vision) +from .utils import (AutoWeightsLoader, WeightsMapper, + init_vllm_registered_model, maybe_prefix, + merge_multimodal_embeddings) +from .vision import get_vit_attn_backend + +logger = init_logger(__name__) + +# === Vision Inputs === # + + +class Qwen2_5_VLImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + pixel_values: torch.Tensor + """Shape: + `(num_patches, num_channels * patch_size * patch_size)` + """ + + image_grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +class Qwen2_5_VLImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + image_embeds: torch.Tensor + """Supported types: + - List[`torch.Tensor`]: A list of tensors holding all images' features. + Each tensor holds an image's features. + - `torch.Tensor`: A tensor holding all images' features + (concatenation of all images' feature tensors). + + Tensor shape: `(num_image_features, hidden_size)` + - `num_image_features` varies based on + the number and resolution of the images. + - `hidden_size` must match the hidden size of language model backbone. + """ + + image_grid_thw: torch.Tensor + """Shape: `(num_images, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs, + Qwen2_5_VLImageEmbeddingInputs] + + +class Qwen2_5_VLVideoPixelInputs(TypedDict): + type: Literal["pixel_values_videos"] + pixel_values_videos: torch.Tensor + """Shape: + `(num_patches, + num_channels * temporal_patch_size * patch_size * patch_size)` + """ + + video_grid_thw: torch.Tensor + """Shape: `(num_videos, 3)` + + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + second_per_grid_ts: torch.Tensor + """ + The video time interval (in seconds) for each grid along the temporal + dimension in the 3D position IDs. Returned when `videos` is not `None`. + """ + + +class Qwen2_5_VLVideoEmbeddingInputs(TypedDict): + type: Literal["video_embeds"] + video_embeds: torch.Tensor + """Supported types: + - List[`torch.Tensor`]: A list of tensors holding all videos' features. + Each tensor holds an video's features. + - `torch.Tensor`: A tensor holding all videos' features + (concatenation of all videos' feature tensors). + + Tensor shape: `(num_image_features, hidden_size)` + - `num_image_features` varies based on + the number and resolution of the videos. + - `hidden_size` must match the hidden size of language model backbone. + """ + + video_grid_thw: torch.Tensor + """Shape: `(num_videos, 3)` + This should be in `(grid_t, grid_h, grid_w)` format. + """ + + +Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs, + Qwen2_5_VLVideoEmbeddingInputs] + +# === Vision Encoder === # + + +class Qwen2_5_VisionMLP(nn.Module): + + def __init__(self, + in_features: int, + hidden_features: int, + bias: bool = False, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): + super().__init__() + self.gate_proj = ColumnParallelLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_proj") + self.up_proj = ColumnParallelLinear(in_features, + hidden_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj") + self.down_proj = RowParallelLinear(hidden_features, + in_features, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + self.act_fn = act_fn + + def forward(self, x: torch.Tensor): + x_gate, _ = self.gate_proj(x) + x_gate = self.act_fn(x_gate) + x_up, _ = self.up_proj(x) + x_down, _ = self.down_proj(x_gate * x_up) + return x_down + + +class Qwen2_5_VisionAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + # Per attention head and per partition values. + world_size = parallel_state.get_tensor_model_parallel_world_size() + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, world_size) + + self.qkv = ColumnParallelLinear(input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=f"{prefix}.qkv") + self.proj = RowParallelLinear(input_size=projection_size, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.proj") + + # Detect attention implementation. + self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True) + if self.attn_backend not in { + _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS + }: + raise RuntimeError( + f"Qwen2.5-VL does not support {self.attn_backend} backend now." + ) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor, + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + x = x.view(*new_x_shape) + + # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim] + q, k, v = dist_utils.split_tensor_along_last_dim(x, 3) + batch_size = q.shape[1] + + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() + for x in (q, k, v)) + if rotary_pos_emb is not None: + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + + if self.attn_backend == _Backend.FLASH_ATTN: + # from vllm_flash_attn.flash_attn_interface import ( + # flash_attn_varlen_func) + from flash_attn import flash_attn_varlen_func + + q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + output = flash_attn_varlen_func(q, + k, + v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_k=max_seqlen, + dropout_p=0, + causal=False) + + context_layer = rearrange(output, + "(b s) ... -> b s ...", + b=batch_size) + elif self.attn_backend == _Backend.TORCH_SDPA: + seq_length = q.size(1) + q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v]) + attention_mask = torch.zeros([1, seq_length, seq_length], + device=q.device, + dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i], + cu_seqlens[i - 1]:cu_seqlens[i]] = True + output = F.scaled_dot_product_attention(q, + k, + v, + attention_mask, + dropout_p=0.0) + context_layer = rearrange(output, "b h s d -> b s h d ") + elif self.attn_backend == _Backend.XFORMERS: + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalMask + + seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() + attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, + kv_seqlen=None) + + context_layer = xops.memory_efficient_attention_forward( + q, k, v, attn_bias=attn_bias, p=0, scale=None) + context_layer = rearrange(context_layer, + "b s h d -> s b (h d)").contiguous() + + output, _ = self.proj(context_layer) + return output + + +class Qwen2RMSNorm(nn.Module): + + def __init__(self, hidden_size, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class Qwen2_5_VisionBlock(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int, + mlp_hidden_dim: int, + act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.attn = Qwen2_5_VisionAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") + self.mlp = Qwen2_5_VisionMLP(dim, + mlp_hidden_dim, + act_fn=act_fn, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + + def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, + rotary_pos_emb: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb) + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen2_5_VisionPatchEmbed(nn.Module): + + def __init__( + self, + patch_size: int = 14, + temporal_patch_size: int = 2, + in_channels: int = 3, + hidden_size: int = 1152, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.hidden_size = hidden_size + + kernel_size = (temporal_patch_size, patch_size, patch_size) + self.proj = nn.Conv3d(in_channels, + hidden_size, + kernel_size=kernel_size, + stride=kernel_size, + bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + L, C = x.shape + x = x.view(L, -1, self.temporal_patch_size, self.patch_size, + self.patch_size) + x = self.proj(x).view(L, self.hidden_size) + return x + + +class Qwen2_5_VisionPatchMerger(nn.Module): + + def __init__( + self, + d_model: int, + context_dim: int, + norm_layer: Optional[Callable[[int], nn.Module]] = None, + spatial_merge_size: int = 2, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.ln_q = norm_layer(context_dim) + self.mlp = nn.ModuleList([ + ColumnParallelLinear(self.hidden_size, + self.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.0"), + nn.GELU(), + RowParallelLinear(self.hidden_size, + d_model, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.mlp.2"), + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.ln_q(x) + x = x.view(-1, self.hidden_size) + + mlp_fc1, mlp_act, mlp_fc2 = self.mlp + x_parallel, _ = mlp_fc1(x) + x_parallel = mlp_act(x_parallel) + out, _ = mlp_fc2(x_parallel) + return out + + +class Qwen2_5_VisionRotaryEmbedding(nn.Module): + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + self.dim = dim + self.theta = theta + inv_freq = 1.0 / (theta + **(torch.arange(0, dim, 2, dtype=torch.float) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self._seq_len_cached = 0 + self._freqs_cached = None + + def update_freqs_cache(self, seqlen: int) -> None: + if seqlen > self._seq_len_cached: + seqlen *= 2 + self._seq_len_cached = seqlen + self.inv_freq = 1.0 / (self.theta**(torch.arange( + 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device) + / self.dim)) + seq = torch.arange(seqlen, + device=self.inv_freq.device, + dtype=self.inv_freq.dtype) + freqs = torch.outer(seq, self.inv_freq) + self._freqs_cached = freqs + + def forward(self, seqlen: int) -> torch.Tensor: + self.update_freqs_cache(seqlen) + return self._freqs_cached[:seqlen] + + +class Qwen2_5_VisionTransformer(nn.Module): + + def __init__( + self, + vision_config: Qwen2_5_VLVisionConfig, + norm_eps: float = 1e-6, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + patch_size = vision_config.patch_size + temporal_patch_size = vision_config.temporal_patch_size + in_channels = vision_config.in_channels + depth = vision_config.depth + self.hidden_size = vision_config.hidden_size + self.num_heads = vision_config.num_heads + + # args for get_window_index + self.window_size = vision_config.window_size + self.patch_size = vision_config.patch_size + self.spatial_merge_size = vision_config.spatial_merge_size + self.fullatt_block_indexes = vision_config.fullatt_block_indexes + self.spatial_merge_unit = self.spatial_merge_size**2 + + self.patch_embed = Qwen2_5_VisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + hidden_size=self.hidden_size, + ) + + # NOTE: We use torch native RMSNorm here for precision purposes. + norm_layer = partial(Qwen2RMSNorm, eps=norm_eps) + head_dim = self.hidden_size // self.num_heads + self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2) + + self.blocks = nn.ModuleList([ + Qwen2_5_VisionBlock( + dim=self.hidden_size, + num_heads=self.num_heads, + mlp_hidden_dim=vision_config.intermediate_size, + act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act], + norm_layer=norm_layer, + quant_config=quant_config, + prefix=f"{prefix}.blocks.{layer_idx}") + for layer_idx in range(depth) + ]) + self.merger = Qwen2_5_VisionPatchMerger( + d_model=vision_config.out_hidden_size, + context_dim=self.hidden_size, + norm_layer=norm_layer, + spatial_merge_size=self.spatial_merge_size, + quant_config=quant_config, + prefix=f"{prefix}.merger", + ) + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + pos_ids = [] + for t, h, w in grid_thw: + hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + hpos_ids = hpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + wpos_ids = wpos_ids.reshape( + h // self.spatial_merge_size, + self.spatial_merge_size, + w // self.spatial_merge_size, + self.spatial_merge_size, + ).permute(0, 2, 1, 3).flatten() + pos_ids.append( + torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) + pos_ids = torch.cat(pos_ids, dim=0) + max_grid_size = grid_thw[:, 1:].max() + rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size) + rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1) + return rotary_pos_emb + + def get_window_index(self, grid_thw): + window_index: list = [] + cu_window_seqlens: list = [0] + window_index_id = 0 + vit_merger_window_size = (self.window_size // + self.spatial_merge_size // self.patch_size) + + for grid_t, grid_h, grid_w in grid_thw: + llm_grid_h = grid_h // self.spatial_merge_size + llm_grid_w = grid_w // self.spatial_merge_size + index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape( + grid_t, llm_grid_h, llm_grid_w) + pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size + pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size + num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size + num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size + index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100) + index_padded = index_padded.reshape(grid_t, num_windows_h, + vit_merger_window_size, + num_windows_w, + vit_merger_window_size) + index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape( + grid_t, num_windows_h * num_windows_w, vit_merger_window_size, + vit_merger_window_size) + seqlens = (index_padded != -100).sum([2, 3]).reshape(-1) + index_padded = index_padded.reshape(-1) + index_new = index_padded[index_padded != -100] + window_index.append(index_new + window_index_id) + cu_seqlens_tmp = seqlens.cumsum( + 0) * self.spatial_merge_unit + cu_window_seqlens[-1] + cu_window_seqlens.extend(cu_seqlens_tmp.tolist()) + window_index_id += (grid_t * llm_grid_h * llm_grid_w).item() + window_index = torch.cat(window_index, dim=0) + return window_index, cu_window_seqlens + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + # patchify + hidden_states = x.to(device=self.device, dtype=self.dtype) + hidden_states = self.patch_embed(hidden_states) + + # compute position embedding + rotary_pos_emb = self.rot_pos_emb(grid_thw) + + # windows attention + window_index, cu_window_seqlens = self.get_window_index(grid_thw) + cu_window_seqlens = torch.tensor( + cu_window_seqlens, + device=hidden_states.device, + dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32) + cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens) + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + hidden_states = hidden_states[window_index, :, :] + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape( + seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1) + rotary_pos_emb = rotary_pos_emb[window_index, :, :] + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + # compute cu_seqlens + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0) + + # transformers + hidden_states = hidden_states.unsqueeze(1) + for layer_num, blk in enumerate(self.blocks): + if layer_num in self.fullatt_block_indexes: + cu_seqlens_now = cu_seqlens + else: + cu_seqlens_now = cu_window_seqlens + hidden_states = blk(hidden_states, + cu_seqlens=cu_seqlens_now, + rotary_pos_emb=rotary_pos_emb) + + # adapter + hidden_states = self.merger(hidden_states) + reverse_indices = torch.argsort(window_index) + hidden_states = hidden_states[reverse_indices, :] + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith("qkv.weight"): + visual_num_heads = self.num_heads + visual_embed_dim = self.hidden_size + head_size = visual_embed_dim // visual_num_heads + loaded_weight = loaded_weight.view(3, visual_num_heads, + head_size, + visual_embed_dim) + loaded_weight = loaded_weight.transpose(0, 1) + loaded_weight = loaded_weight.reshape(-1, visual_embed_dim) + elif name.endswith("qkv.bias"): + visual_num_heads = self.num_heads + visual_embed_dim = self.hidden_size + head_size = visual_embed_dim // visual_num_heads + loaded_weight = loaded_weight.view(3, visual_num_heads, + head_size) + loaded_weight = loaded_weight.transpose(0, 1) + loaded_weight = loaded_weight.reshape(-1) + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo): + + def get_hf_config(self): + return self.ctx.get_hf_config(Qwen2_5_VLConfig) + + def get_hf_processor( + self, + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + fps: Optional[float] = 2.0, + ) -> Qwen2_5_VLProcessor: + hf_processor = self.ctx.get_hf_processor(Qwen2_5_VLProcessor) + image_processor = hf_processor.image_processor # type: ignore + assert isinstance(image_processor, Qwen2_5_VLImageProcessor) + + if min_pixels: + image_processor.min_pixels = min_pixels + if max_pixels: + image_processor.max_pixels = max_pixels + if max_pixels or min_pixels: + image_processor.size = { + "min_pixels": image_processor.min_pixels, + "max_pixels": image_processor.max_pixels, + } + + return hf_processor + + def get_image_processor( + self, + *, + min_pixels: Optional[int] = None, + max_pixels: Optional[int] = None, + fps: Optional[float] = 2.0, + ) -> Qwen2_5_VLImageProcessor: + hf_processor = self.get_hf_processor( + min_pixels=min_pixels, + max_pixels=max_pixels, + fps=fps, + ) + image_processor = hf_processor.image_processor # type: ignore + assert isinstance(image_processor, Qwen2_5_VLImageProcessor) + return image_processor + + +class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor): + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs), + second_per_grid_ts=MultiModalFieldConfig.batched("video"), + ) + + +@MULTIMODAL_REGISTRY.register_processor( + Qwen2_5_VLMultiModalProcessor, + info=Qwen2_5_VLProcessingInfo, + dummy_inputs=Qwen2_5_VLDummyInputsBuilder) +class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsLoRA, SupportsPP): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ] + } + + # LoRA specific attributes, TODO: double check + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "gate_proj" + "up_proj", + # vision tower + "qkv", + "attn.proj", # Distinguish patch_embed.proj + "fc1", + "fc2", + # projector + "mlp.0", + "mlp.2" + ] + embedding_modules = {} + embedding_padding_modules = [] + + # To ensure correct weight loading and mapping. + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={ + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + }) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + + self.config = config + self.multimodal_config = multimodal_config + + self.visual = Qwen2_5_VisionTransformer( + config.vision_config, + norm_eps=getattr(config, "rms_norm_eps", 1e-6), + quant_config=self._maybe_ignore_quant_config(quant_config), + prefix=maybe_prefix(prefix, "visual"), + ) + + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen2ForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return get_sampler() + + def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig): + # GPTQ configs do not have a list of ignored modules, however AutoGPTQ + # seems to avoid vision encoder sections for some models. + if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)): + return None + return quant_config + + def _validate_and_reshape_mm_tensor(self, mm_input: object, + name: str) -> torch.Tensor: + if not isinstance(mm_input, (torch.Tensor, list)): + raise ValueError(f"Incorrect type of {name}. " + f"Got type: {type(mm_input)}") + if isinstance(mm_input, torch.Tensor): + if mm_input.ndim == 2: + return mm_input + if mm_input.ndim != 3: + raise ValueError(f"{name} should be 2D or batched 3D tensor. " + f"Got ndim: {mm_input.ndim} " + f"(shape={mm_input.shape})") + return torch.concat(list(mm_input)) + else: + return torch.concat(mm_input) + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]: + pixel_values = kwargs.pop("pixel_values", None) + image_embeds = kwargs.pop("image_embeds", None) + image_grid_thw = kwargs.pop("image_grid_thw", None) + + if pixel_values is None and image_embeds is None: + return None + + if pixel_values is not None: + pixel_values = self._validate_and_reshape_mm_tensor( + pixel_values, "image pixel values") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of image pixel values. " + f"Got type: {type(pixel_values)}") + + return Qwen2_5_VLImagePixelInputs(type="pixel_values", + pixel_values=pixel_values, + image_grid_thw=image_grid_thw) + + if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + image_grid_thw = self._validate_and_reshape_mm_tensor( + image_grid_thw, "image grid_thw") + + if not isinstance(image_embeds, torch.Tensor): + raise ValueError("Incorrect type of image embeddings. " + f"Got type: {type(image_embeds)}") + return Qwen2_5_VLImageEmbeddingInputs( + type="image_embeds", + image_embeds=image_embeds, + image_grid_thw=image_grid_thw) + + def _parse_and_validate_video_input( + self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]: + pixel_values_videos = kwargs.pop("pixel_values_videos", None) + video_embeds = kwargs.pop("video_embeds", None) + video_grid_thw = kwargs.pop("video_grid_thw", None) + second_per_grid_ts = kwargs.pop("second_per_grid_ts", None) + + if pixel_values_videos is None and video_embeds is None: + return None + + if pixel_values_videos is not None: + pixel_values_videos = self._validate_and_reshape_mm_tensor( + pixel_values_videos, "video pixel values") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + return Qwen2_5_VLVideoPixelInputs( + type="pixel_values_videos", + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + ) + + if video_embeds is not None: + video_embeds = self._validate_and_reshape_mm_tensor( + video_embeds, "video embeds") + video_grid_thw = self._validate_and_reshape_mm_tensor( + video_grid_thw, "video grid_thw") + + if not isinstance(video_embeds, torch.Tensor): + raise ValueError("Incorrect type of video embeddings. " + f"Got type: {type(video_embeds)}") + return Qwen2_5_VLVideoEmbeddingInputs( + type="video_embeds", + video_embeds=video_embeds, + video_grid_thw=video_grid_thw) + + def _process_image_input( + self, + image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = image_input["image_grid_thw"] + assert grid_thw.ndim == 2 + + if image_input["type"] == "image_embeds": + image_embeds = image_input["image_embeds"].type(self.visual.dtype) + else: + pixel_values = image_input["pixel_values"].type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=grid_thw) + + # Split concatenated embeddings for each image item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return image_embeds.split(sizes.tolist()) + + def _process_video_input( + self, + video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]: + + grid_thw = video_input["video_grid_thw"] + assert grid_thw.ndim == 2 + + if video_input["type"] == "video_embeds": + video_embeds = video_input["video_embeds"].type(self.visual.dtype) + else: + pixel_values_videos = video_input["pixel_values_videos"].type( + self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw) + + # Split concatenated embeddings for each video item. + merge_size = self.visual.spatial_merge_size + sizes = grid_thw.prod(-1) // merge_size // merge_size + + return video_embeds.split(sizes.tolist()) + + def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict: + modalities = {} + + # Preserve the order of modalities if there are multiple of them + # from the order of kwargs. + for input_key in kwargs: + if input_key in ("pixel_values", + "image_embeds") and "images" not in modalities: + modalities["images"] = self._parse_and_validate_image_input( + **kwargs) + if input_key in ("pixel_values_videos", + "video_embeds") and "videos" not in modalities: + modalities["videos"] = self._parse_and_validate_video_input( + **kwargs) + return modalities + + def get_multimodal_embeddings( + self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]: + + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if not modalities: + return None + + # The result multimodal_embeddings is tuple of tensors, with each + # tensor correspoending to a multimodal data item (image or video). + multimodal_embeddings: tuple[torch.Tensor, ...] = () + + # NOTE: It is important to iterate over the keys in this dictionary + # to preserve the order of the modalities. + for modality in modalities: + if modality == "images": + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + multimodal_embeddings += vision_embeddings + if modality == "videos": + video_input = modalities["videos"] + video_embeddings = self._process_video_input(video_input) + multimodal_embeddings += video_embeddings + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + [self.config.image_token_id, self.config.video_token_id]) + return inputs_embeds + + def get_input_embeddings_v0( + self, + input_ids: torch.Tensor, + image_input: Optional[tuple[torch.Tensor, ...]] = None, + video_input: Optional[tuple[torch.Tensor, ...]] = None, + ) -> torch.Tensor: + + inputs_embeds = self.get_input_embeddings(input_ids) + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = self._process_video_input(video_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + """Run forward pass for Qwen2.5-VL. + + Args: + input_ids: Flattened (concatenated) input_ids corresponding to a + batch. + positions: Flattened (concatenated) position ids corresponding to a + batch. + **NOTE**: If mrope is enabled (default setting for Qwen2.5-VL + opensource models), the shape will be `(3, seq_len)`, + otherwise it will be `(seq_len,). + pixel_values: Pixel values to be fed to a model. + `None` if no images are passed. + image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. + `None` if no images are passed. + pixel_values_videos: Pixel values of videos to be fed to a model. + `None` if no videos are passed. + video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. + `None` if no videos are passed. + second_per_grid_ts: Tensor `(num_videos)` of video time interval ( + in seconds) for each grid along the temporal dimension in the + 3D position IDs. `None` if no videos are passed. + """ + + if intermediate_tensors is not None: + inputs_embeds = None + + # NOTE: In v1, inputs_embeds is always generated at model runner from + # `get_multimodal_embeddings` and `get_input_embeddings`, this + # condition is only for v0 compatibility. + elif inputs_embeds is None: + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) + + if image_input is None and video_input is None: + inputs_embeds = None + else: + if uses_mrope(self.config): + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}") + inputs_embeds = self.get_input_embeddings_v0( + input_ids, + image_input=image_input, + video_input=video_input) + input_ids = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + + loader = AutoWeightsLoader(self) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="language_model", + connector="visual.", + tower_model="visual.merger.") diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 2b2638cf68fc7..34ae7b8c94697 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -650,8 +650,8 @@ def load_weights(self, weights: Iterable[Tuple[str, return loaded_params -class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], - dict[str, torch.Tensor]]): +class Qwen2VLEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], + dict[str, torch.Tensor]]): def __init__(self, data: dict, modality: str) -> None: super().__init__(data, modality) @@ -683,26 +683,26 @@ def get_passthrough_data(self) -> Mapping[str, object]: return self.data -class Qwen2ImageEmbeddingItems(Qwen2EmbeddingItems): +class Qwen2VLImageEmbeddingItems(Qwen2VLEmbeddingItems): def __init__(self, data: dict) -> None: super().__init__(data, "image") -class Qwen2VideoEmbeddingItems(Qwen2EmbeddingItems): +class Qwen2VLVideoEmbeddingItems(Qwen2VLEmbeddingItems): def __init__(self, data: dict) -> None: super().__init__(data, "video") -class Qwen2MultiModalDataParser(MultiModalDataParser): +class Qwen2VLMultiModalDataParser(MultiModalDataParser): def _parse_image_data( self, data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]], ) -> ModalityDataItems[Any, Any]: if isinstance(data, dict): - return Qwen2EmbeddingItems(data, modality="image") + return Qwen2VLEmbeddingItems(data, modality="image") return super()._parse_image_data(data) @@ -711,7 +711,7 @@ def _parse_video_data( data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]], ) -> ModalityDataItems[Any, Any]: if isinstance(data, dict): - return Qwen2EmbeddingItems(data, modality="video") + return Qwen2VLEmbeddingItems(data, modality="video") return super()._parse_video_data(data) @@ -948,7 +948,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo] ): def _get_data_parser(self) -> MultiModalDataParser: - return Qwen2MultiModalDataParser() + return Qwen2VLMultiModalDataParser() def _get_prompt_replacements( self, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 962f95f10fc51..b6708f77d8aff 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -172,6 +172,7 @@ "PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501 "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501 + "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501 "Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501 "UltravoxModel": ("ultravox", "UltravoxModel"), # [Encoder-decoder] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 7841fac1df34b..ec6d04cd49752 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -285,6 +285,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: if self.model_config.uses_mrope: image_grid_thw = [] video_grid_thw = [] + second_per_grid_ts = [] for mm_input in self.requests[req_id].mm_inputs: if mm_input.get("image_grid_thw") is not None: image_grid_thw.extend( @@ -292,6 +293,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: if mm_input.get("video_grid_thw") is not None: video_grid_thw.extend( mm_input["video_grid_thw"].tolist()) + if mm_input.get("second_per_grid_ts") is not None: + second_per_grid_ts.extend( + mm_input["second_per_grid_ts"]) hf_config = self.model_config.hf_config @@ -299,14 +303,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool: self.requests[req_id].mrope_position_delta = \ MRotaryEmbedding.get_input_positions_tensor( self.requests[req_id].prompt_token_ids, + hf_config=hf_config, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, - image_token_id=hf_config.image_token_id, - video_token_id=hf_config.video_token_id, - vision_start_token_id=hf_config.vision_start_token_id, - vision_end_token_id=hf_config.vision_end_token_id, - spatial_merge_size=hf_config.vision_config. - spatial_merge_size, + second_per_grid_ts=second_per_grid_ts, ) req_ids_to_add.append(req_id) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 1c3feece95a5a..9400893105d73 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -386,20 +386,17 @@ def _compute_multi_modal_input(self, "mrope embedding type requires multi-modal input mapper " "returns 'image_grid_thw' or 'video_grid_thw'.") + second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) hf_config = self.runner.model_config.hf_config token_ids = seq_data.get_token_ids() mrope_positions, mrope_position_delta = \ MRotaryEmbedding.get_input_positions( token_ids, + hf_config=hf_config, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, - image_token_id=hf_config.image_token_id, - video_token_id=hf_config.video_token_id, - vision_start_token_id=hf_config.vision_start_token_id, - vision_end_token_id=hf_config.vision_end_token_id, - spatial_merge_size=hf_config.vision_config. - spatial_merge_size, + second_per_grid_ts=second_per_grid_ts, context_len=computed_len, ) seq_data.mrope_position_delta = mrope_position_delta diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0bbba55b3b3f8..12baecde6e42c 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -702,6 +702,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, "mrope embedding type requires multi-modal input mapper " "returns 'image_grid_thw' or 'video_grid_thw'.") + second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None) hf_config = self.runner.model_config.hf_config inter_data.mrope_input_positions = [None] * inter_data.n_seqs @@ -713,14 +714,10 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup, mrope_input_positions, mrope_position_delta = \ MRotaryEmbedding.get_input_positions( token_ids, + hf_config=hf_config, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, - image_token_id=hf_config.image_token_id, - video_token_id=hf_config.video_token_id, - vision_start_token_id=hf_config.vision_start_token_id, - vision_end_token_id=hf_config.vision_end_token_id, - spatial_merge_size=hf_config.vision_config. - spatial_merge_size, + second_per_grid_ts=second_per_grid_ts, context_len=inter_data.context_lens[seq_idx], seq_len=inter_data.seq_lens[seq_idx], )