From bf3b79efb82676219a3275764d8fcf4c70097ce5 Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Wed, 5 Feb 2025 13:31:38 -0800
Subject: [PATCH] [VLM] Qwen2.5-VL
---
docs/source/models/supported_models.md | 11 +
examples/offline_inference/vision_language.py | 31 +
.../vision_language_multi_image.py | 58 +
.../vision_language/test_models.py | 22 +
.../multimodal/processing/test_common.py | 1 +
tests/models/registry.py | 2 +
vllm/entrypoints/chat_utils.py | 4 +-
.../model_executor/layers/rotary_embedding.py | 58 +-
vllm/model_executor/models/qwen2_5_vl.py | 1133 +++++++++++++++++
vllm/model_executor/models/qwen2_vl.py | 16 +-
vllm/model_executor/models/registry.py | 1 +
vllm/v1/worker/gpu_model_runner.py | 12 +-
vllm/worker/cpu_model_runner.py | 9 +-
vllm/worker/model_runner.py | 9 +-
14 files changed, 1315 insertions(+), 52 deletions(-)
create mode 100644 vllm/model_executor/models/qwen2_5_vl.py
diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index d8e2842929515..3e8b2f89642c4 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -846,6 +846,13 @@ See [this page](#generative-models) for more information on how to use generativ
* ✅︎
* ✅︎
* ✅︎
+- * `Qwen2_5_VLForConditionalGeneration`
+ * Qwen2.5-VL
+ * T + IE+ + VE+
+ * `Qwen/Qwen2.5-VL-3B-Instruct`, `Qwen/Qwen2.5-VL-72B-Instruct`, etc.
+ *
+ * ✅︎
+ * ✅︎
- * `UltravoxModel`
* Ultravox
* T + AE+
@@ -880,6 +887,10 @@ The chat template for Pixtral-HF is incorrect (see [discussion](https://huggingf
A corrected version is available at .
:::
+:::{note}
+To use Qwen2.5-VL series models, you have to install the Hugging Face `transformers` library from source via `pip install git+https://github.com/huggingface/transformers`.
+:::
+
### Pooling Models
See [this page](pooling-models) for more information on how to use pooling models.
diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py
index 65940b6ada883..436c36570599a 100644
--- a/examples/offline_inference/vision_language.py
+++ b/examples/offline_inference/vision_language.py
@@ -531,6 +531,36 @@ def run_qwen2_vl(question: str, modality: str):
return llm, prompt, stop_token_ids
+# Qwen2.5-VL
+def run_qwen2_5_vl(question: str, modality: str):
+
+ model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
+
+ llm = LLM(
+ model=model_name,
+ max_model_len=4096,
+ max_num_seqs=5,
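+        # min_pixels/max_pixels bound how many visual tokens each image or
+        # video frame can produce; `fps` is passed through to the HF
+        # processor, which uses it to derive per-grid timing for videos.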
+ mm_processor_kwargs={
+ "min_pixels": 28 * 28,
+ "max_pixels": 1280 * 28 * 28,
+ "fps": 1,
+ },
+ disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+ )
+
+ if modality == "image":
+ placeholder = "<|image_pad|>"
+ elif modality == "video":
+ placeholder = "<|video_pad|>"
+
+ prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
+ f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
+ f"{question}<|im_end|>\n"
+ "<|im_start|>assistant\n")
+ stop_token_ids = None
+ return llm, prompt, stop_token_ids
+
+
model_example_map = {
"aria": run_aria,
"blip-2": run_blip2,
@@ -557,6 +587,7 @@ def run_qwen2_vl(question: str, modality: str):
"pixtral_hf": run_pixtral_hf,
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
+ "qwen2_5_vl": run_qwen2_5_vl,
}
diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py
index 601ac96e16eac..8d2172a606f8d 100644
--- a/examples/offline_inference/vision_language_multi_image.py
+++ b/examples/offline_inference/vision_language_multi_image.py
@@ -392,6 +392,63 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
)
+def load_qwen2_5_vl(question, image_urls: List[str]) -> ModelRequestData:
+ try:
+ from qwen_vl_utils import process_vision_info
+ except ModuleNotFoundError:
+ print('WARNING: `qwen-vl-utils` not installed, input images will not '
+ 'be automatically resized. You can enable this functionality by '
+              'running `pip install qwen-vl-utils`.')
+ process_vision_info = None
+
+ model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
+
+ llm = LLM(
+ model=model_name,
+ max_model_len=32768 if process_vision_info is None else 4096,
+ max_num_seqs=5,
+ limit_mm_per_prompt={"image": len(image_urls)},
+ )
+
+ placeholders = [{"type": "image", "image": url} for url in image_urls]
+ messages = [{
+ "role": "system",
+ "content": "You are a helpful assistant."
+ }, {
+ "role":
+ "user",
+ "content": [
+ *placeholders,
+ {
+ "type": "text",
+ "text": question
+ },
+ ],
+ }]
+
+ processor = AutoProcessor.from_pretrained(model_name)
+
+ prompt = processor.apply_chat_template(messages,
+ tokenize=False,
+ add_generation_prompt=True)
+
+ stop_token_ids = None
+
+ if process_vision_info is None:
+ image_data = [fetch_image(url) for url in image_urls]
+ else:
+ image_data, _ = process_vision_info(messages,
+ return_video_sample_fps=False)
+
+ return ModelRequestData(
+ llm=llm,
+ prompt=prompt,
+ stop_token_ids=stop_token_ids,
+ image_data=image_data,
+ chat_template=None,
+ )
+
+
model_example_map = {
"aria": load_aria,
"deepseek_vl_v2": load_deepseek_vl2,
@@ -404,6 +461,7 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData:
"pixtral_hf": load_pixtral_hf,
"qwen_vl_chat": load_qwen_vl_chat,
"qwen2_vl": load_qwen2_vl,
+ "qwen2_5_vl": load_qwen2_5_vl,
}
diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py
index 85bc4ac13182d..95505dcf5c29f 100644
--- a/tests/models/decoder_only/vision_language/test_models.py
+++ b/tests/models/decoder_only/vision_language/test_models.py
@@ -121,6 +121,8 @@
else ("half", "float")),
marks=[pytest.mark.core_model],
),
+ # TODO(ywang96): Move Qwen2-VL out of core models in favor of Qwen2.5-VL
+    # once we upgrade to transformers>=4.49.0.
"qwen2_vl": VLMTestInfo(
models=["Qwen/Qwen2-VL-2B-Instruct"],
test_type=(
@@ -138,6 +140,26 @@
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
+ "qwen2_5_vl": VLMTestInfo(
+ models=["Qwen/Qwen2.5-VL-3B-Instruct"],
+ test_type=(
+ VLMTestType.IMAGE,
+ VLMTestType.MULTI_IMAGE,
+ VLMTestType.VIDEO
+ ),
+ prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
+ img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501
+ video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501
+ max_model_len=4096,
+ max_num_seqs=2,
+ auto_cls=AutoModelForVision2Seq,
+ vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
+ image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
+ marks=[pytest.mark.skipif(
+ TRANSFORMERS_VERSION < "4.49.0",
+ reason="HF model requires transformers>=4.49.0",
+ ), pytest.mark.core_model, pytest.mark.cpu_model],
+ ),
#### Extended model tests
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],
diff --git a/tests/models/multimodal/processing/test_common.py b/tests/models/multimodal/processing/test_common.py
index 5cd749cbd7795..77cf3442df905 100644
--- a/tests/models/multimodal/processing/test_common.py
+++ b/tests/models/multimodal/processing/test_common.py
@@ -161,6 +161,7 @@ def _test_processing_correctness(
"nvidia/NVLM-D-72B",
"Qwen/Qwen-VL-Chat",
"Qwen/Qwen2-VL-2B-Instruct",
+ "Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"fixie-ai/ultravox-v0_3",
])
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 285fbe4848090..20787fe008aa8 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -264,6 +264,8 @@ def check_available_online(
trust_remote_code=True),
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
+ "Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
+ min_transformers_version="4.49"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_3",
trust_remote_code=True),
# [Encoder-decoder]
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index 3a6e75b1d8e58..f04902ae1c767 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -410,7 +410,7 @@ def _placeholder_str(self, modality: ModalityStr,
return ""
if model_type == "mllama":
return "<|image|>"
- if model_type == "qwen2_vl":
+ if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|image_pad|><|vision_end|>"
if model_type == "molmo":
return ""
@@ -430,7 +430,7 @@ def _placeholder_str(self, modality: ModalityStr,
return "()"
raise TypeError(f"Unknown model type: {model_type}")
elif modality == "video":
- if model_type == "qwen2_vl":
+ if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|video_pad|><|vision_end|>"
if model_type in ("minicpmo", "minicpmv"):
return "()"
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 814c3b7d9cd83..b3b9b0e876057 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -27,6 +27,7 @@
import torch
import torch.nn as nn
+from transformers import PretrainedConfig
from vllm.model_executor.custom_op import CustomOp
@@ -772,8 +773,12 @@ def __init__(
dtype: torch.dtype,
mrope_section: Optional[List[int]] = None,
) -> None:
- super().__init__(head_size, rotary_dim, max_position_embeddings, base,
- is_neox_style, dtype)
+ # In Qwen2.5-VL, the maximum index value is related to the duration of
+        # the input video. We enlarge max_position_embeddings by a factor of 4
+        # to get a larger cos and sin cache.
+ self.cache_max_position_num = max_position_embeddings * 4
+ super().__init__(head_size, rotary_dim, self.cache_max_position_num,
+ base, is_neox_style, dtype)
self.mrope_section = mrope_section
if self.mrope_section:
@@ -831,13 +836,10 @@ def forward(
@staticmethod
def get_input_positions(
input_tokens: List[int],
+ hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
- image_token_id: int,
- video_token_id: int,
- vision_start_token_id: int,
- vision_end_token_id: int,
- spatial_merge_size: int,
+ second_per_grid_ts: Optional[List[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
) -> Tuple[List[List[int]], int]:
@@ -845,16 +847,13 @@ def get_input_positions(
llm_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
- input_tokens,
- image_grid_thw,
- video_grid_thw,
- image_token_id,
- video_token_id,
- vision_start_token_id,
- vision_end_token_id,
- spatial_merge_size,
- context_len,
- seq_len,
+ input_tokens=input_tokens,
+ hf_config=hf_config,
+ image_grid_thw=image_grid_thw,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ context_len=context_len,
+ seq_len=seq_len,
)
return llm_positions.tolist(), mrope_position_delta
@@ -862,18 +861,22 @@ def get_input_positions(
@staticmethod
def get_input_positions_tensor(
input_tokens: List[int],
+ hf_config: PretrainedConfig,
image_grid_thw: Union[List[List[int]], torch.Tensor],
video_grid_thw: Union[List[List[int]], torch.Tensor],
- image_token_id: int,
- video_token_id: int,
- vision_start_token_id: int,
- vision_end_token_id: int,
- spatial_merge_size: int,
+ second_per_grid_ts: Optional[List[float]] = None,
context_len: int = 0,
seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""
+ image_token_id = hf_config.image_token_id
+ video_token_id = hf_config.video_token_id
+ vision_start_token_id = hf_config.vision_start_token_id
+ spatial_merge_size = hf_config.vision_config.spatial_merge_size
+ tokens_per_second = getattr(hf_config.vision_config,
+ "tokens_per_second", 1.0)
+
if isinstance(image_grid_thw, torch.Tensor):
image_grid_thw = image_grid_thw.tolist()
if isinstance(video_grid_thw, torch.Tensor):
@@ -892,6 +895,7 @@ def get_input_positions_tensor(
image_index, video_index = 0, 0
for _ in range(image_nums + video_nums):
+ video_second_per_grid_t = 0.0
if image_token_id in input_tokens and remain_images > 0:
ed_image = input_tokens.index(image_token_id, st)
else:
@@ -915,9 +919,13 @@ def get_input_positions_tensor(
video_grid_thw[video_index][1],
video_grid_thw[video_index][2],
)
+ video_second_per_grid_t = 1.0
+ if second_per_grid_ts is not None:
+ video_second_per_grid_t = second_per_grid_ts[video_index]
video_index += 1
remain_videos -= 1
ed = ed_video
+
llm_grid_t, llm_grid_h, llm_grid_w = \
t, h // spatial_merge_size, w // spatial_merge_size
text_len = ed - st
@@ -927,8 +935,10 @@ def get_input_positions_tensor(
llm_pos_ids_list.append(
torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
- t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
- -1, llm_grid_h * llm_grid_w).flatten()
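+        # Scale the temporal index by seconds-per-grid and tokens_per_second
+        # so video position ids track elapsed time rather than raw frame
+        # count (images have grid_t == 1, so this leaves them unchanged).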
+ t_index = (torch.arange(llm_grid_t).view(-1, 1).expand(
+ -1, llm_grid_h * llm_grid_w) * video_second_per_grid_t *
+ tokens_per_second).long().flatten()
+
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
llm_grid_t, -1, llm_grid_w).flatten()
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
diff --git a/vllm/model_executor/models/qwen2_5_vl.py b/vllm/model_executor/models/qwen2_5_vl.py
new file mode 100644
index 0000000000000..e93cf46b900b6
--- /dev/null
+++ b/vllm/model_executor/models/qwen2_5_vl.py
@@ -0,0 +1,1133 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# Adapted from
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+# Copyright 2025 The vLLM team.
+# Copyright 2025 The Qwen Team.
+# Copyright 2025 The HuggingFace Inc. team.
+# All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
+from functools import cached_property, partial
+from typing import (Callable, Iterable, List, Literal, Mapping, Optional, Set,
+ Tuple, TypedDict, Union)
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import BatchFeature
+from transformers.models.qwen2_5_vl import (Qwen2_5_VLImageProcessor,
+ Qwen2_5_VLProcessor)
+from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
+ Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)
+
+from vllm.attention import AttentionMetadata
+from vllm.config import VllmConfig
+from vllm.distributed import parallel_state
+from vllm.distributed import utils as dist_utils
+from vllm.logger import init_logger
+from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+ RowParallelLinear)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+ GPTQMarlinConfig)
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal.inputs import MultiModalFieldConfig
+from vllm.platforms import _Backend
+from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.config import uses_mrope
+
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
+from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
+from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
+ apply_rotary_pos_emb_vision)
+from .utils import (AutoWeightsLoader, WeightsMapper,
+ init_vllm_registered_model, maybe_prefix,
+ merge_multimodal_embeddings)
+from .vision import get_vit_attn_backend
+
+logger = init_logger(__name__)
+
+# === Vision Inputs === #
+
+
+class Qwen2_5_VLImagePixelInputs(TypedDict):
+ type: Literal["pixel_values"]
+ pixel_values: torch.Tensor
+ """Shape:
+ `(num_patches, num_channels * patch_size * patch_size)`
+ """
+
+ image_grid_thw: torch.Tensor
+ """Shape: `(num_images, 3)`
+ This should be in `(grid_t, grid_h, grid_w)` format.
+ """
+
+
+class Qwen2_5_VLImageEmbeddingInputs(TypedDict):
+ type: Literal["image_embeds"]
+ image_embeds: torch.Tensor
+ """Supported types:
+ - List[`torch.Tensor`]: A list of tensors holding all images' features.
+ Each tensor holds an image's features.
+ - `torch.Tensor`: A tensor holding all images' features
+ (concatenation of all images' feature tensors).
+
+ Tensor shape: `(num_image_features, hidden_size)`
+ - `num_image_features` varies based on
+ the number and resolution of the images.
+    - `hidden_size` must match the hidden size of the language model backbone.
+ """
+
+ image_grid_thw: torch.Tensor
+ """Shape: `(num_images, 3)`
+ This should be in `(grid_t, grid_h, grid_w)` format.
+ """
+
+
+Qwen2_5_VLImageInputs = Union[Qwen2_5_VLImagePixelInputs,
+ Qwen2_5_VLImageEmbeddingInputs]
+
+
+class Qwen2_5_VLVideoPixelInputs(TypedDict):
+ type: Literal["pixel_values_videos"]
+ pixel_values_videos: torch.Tensor
+ """Shape:
+ `(num_patches,
+ num_channels * temporal_patch_size * patch_size * patch_size)`
+ """
+
+ video_grid_thw: torch.Tensor
+ """Shape: `(num_videos, 3)`
+
+ This should be in `(grid_t, grid_h, grid_w)` format.
+ """
+
+ second_per_grid_ts: torch.Tensor
+ """
+ The video time interval (in seconds) for each grid along the temporal
+ dimension in the 3D position IDs. Returned when `videos` is not `None`.
+ """
+
+
+class Qwen2_5_VLVideoEmbeddingInputs(TypedDict):
+ type: Literal["video_embeds"]
+ video_embeds: torch.Tensor
+ """Supported types:
+ - List[`torch.Tensor`]: A list of tensors holding all videos' features.
+ Each tensor holds an video's features.
+ - `torch.Tensor`: A tensor holding all videos' features
+ (concatenation of all videos' feature tensors).
+
+ Tensor shape: `(num_image_features, hidden_size)`
+ - `num_image_features` varies based on
+ the number and resolution of the videos.
+ - `hidden_size` must match the hidden size of language model backbone.
+ """
+
+ video_grid_thw: torch.Tensor
+ """Shape: `(num_videos, 3)`
+ This should be in `(grid_t, grid_h, grid_w)` format.
+ """
+
+
+Qwen2_5_VLVideoInputs = Union[Qwen2_5_VLVideoPixelInputs,
+ Qwen2_5_VLVideoEmbeddingInputs]
+
+# === Vision Encoder === #
+
+
+class Qwen2_5_VisionMLP(nn.Module):
+
+ def __init__(self,
+ in_features: int,
+ hidden_features: int,
+ bias: bool = False,
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = ""):
+ super().__init__()
+ self.gate_proj = ColumnParallelLinear(in_features,
+ hidden_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.gate_proj")
+ self.up_proj = ColumnParallelLinear(in_features,
+ hidden_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.up_proj")
+ self.down_proj = RowParallelLinear(hidden_features,
+ in_features,
+ bias=bias,
+ quant_config=quant_config,
+ prefix=f"{prefix}.down_proj")
+ self.act_fn = act_fn
+
+ def forward(self, x: torch.Tensor):
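+        # SwiGLU: down_proj(act_fn(gate_proj(x)) * up_proj(x)).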
+ x_gate, _ = self.gate_proj(x)
+ x_gate = self.act_fn(x_gate)
+ x_up, _ = self.up_proj(x)
+ x_down, _ = self.down_proj(x_gate * x_up)
+ return x_down
+
+
+class Qwen2_5_VisionAttention(nn.Module):
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ projection_size: int,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ # Per attention head and per partition values.
+ world_size = parallel_state.get_tensor_model_parallel_world_size()
+ self.hidden_size_per_attention_head = dist_utils.divide(
+ projection_size, num_heads)
+ self.num_attention_heads_per_partition = dist_utils.divide(
+ num_heads, world_size)
+
+ self.qkv = ColumnParallelLinear(input_size=embed_dim,
+ output_size=3 * projection_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.qkv")
+ self.proj = RowParallelLinear(input_size=projection_size,
+ output_size=embed_dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.proj")
+
+ # Detect attention implementation.
+ self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+ if self.attn_backend not in {
+ _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS
+ }:
+ raise RuntimeError(
+ f"Qwen2.5-VL does not support {self.attn_backend} backend now."
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ cu_seqlens: torch.Tensor,
+ rotary_pos_emb: torch.Tensor,
+ ) -> torch.Tensor:
+ # [s, b, c] --> [s, b, head * 3 * head_dim]
+ x, _ = self.qkv(x)
+
+ # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+ new_x_shape = x.size()[:-1] + (
+ self.num_attention_heads_per_partition,
+ 3 * self.hidden_size_per_attention_head,
+ )
+ x = x.view(*new_x_shape)
+
+ # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+ q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
+ batch_size = q.shape[1]
+
+ q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
+ for x in (q, k, v))
+ if rotary_pos_emb is not None:
+ q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+ k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+
+ if self.attn_backend == _Backend.FLASH_ATTN:
+ # from vllm_flash_attn.flash_attn_interface import (
+ # flash_attn_varlen_func)
+ from flash_attn import flash_attn_varlen_func
+
+ q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
+
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ output = flash_attn_varlen_func(q,
+ k,
+ v,
+ cu_seqlens_q=cu_seqlens,
+ cu_seqlens_k=cu_seqlens,
+ max_seqlen_q=max_seqlen,
+ max_seqlen_k=max_seqlen,
+ dropout_p=0,
+ causal=False)
+
+ context_layer = rearrange(output,
+ "(b s) ... -> b s ...",
+ b=batch_size)
+ elif self.attn_backend == _Backend.TORCH_SDPA:
+ seq_length = q.size(1)
+ q, k, v = (rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
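+            # Build a block-diagonal boolean mask so each cu_seqlens segment
+            # only attends to itself.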
+ attention_mask = torch.zeros([1, seq_length, seq_length],
+ device=q.device,
+ dtype=torch.bool)
+ for i in range(1, len(cu_seqlens)):
+ attention_mask[..., cu_seqlens[i - 1]:cu_seqlens[i],
+ cu_seqlens[i - 1]:cu_seqlens[i]] = True
+ output = F.scaled_dot_product_attention(q,
+ k,
+ v,
+ attention_mask,
+ dropout_p=0.0)
+ context_layer = rearrange(output, "b h s d -> b s h d ")
+ elif self.attn_backend == _Backend.XFORMERS:
+ from xformers import ops as xops
+ from xformers.ops.fmha.attn_bias import BlockDiagonalMask
+
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
+ attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
+ kv_seqlen=None)
+
+ context_layer = xops.memory_efficient_attention_forward(
+ q, k, v, attn_bias=attn_bias, p=0, scale=None)
+ context_layer = rearrange(context_layer,
+ "b s h d -> s b (h d)").contiguous()
+
+ output, _ = self.proj(context_layer)
+ return output
+
+
+class Qwen2RMSNorm(nn.Module):
+
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance +
+ self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Qwen2_5_VisionBlock(nn.Module):
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_hidden_dim: int,
+ act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
+ norm_layer: Optional[Callable[[int], nn.Module]] = None,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ if norm_layer is None:
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ self.norm1 = norm_layer(dim)
+ self.norm2 = norm_layer(dim)
+ self.attn = Qwen2_5_VisionAttention(embed_dim=dim,
+ num_heads=num_heads,
+ projection_size=dim,
+ quant_config=quant_config,
+ prefix=f"{prefix}.attn")
+ self.mlp = Qwen2_5_VisionMLP(dim,
+ mlp_hidden_dim,
+ act_fn=act_fn,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp")
+
+ def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
+ rotary_pos_emb: torch.Tensor) -> torch.Tensor:
+ x = x + self.attn(self.norm1(x),
+ cu_seqlens=cu_seqlens,
+ rotary_pos_emb=rotary_pos_emb)
+ x = x + self.mlp(self.norm2(x))
+ return x
+
+
+class Qwen2_5_VisionPatchEmbed(nn.Module):
+
+ def __init__(
+ self,
+ patch_size: int = 14,
+ temporal_patch_size: int = 2,
+ in_channels: int = 3,
+ hidden_size: int = 1152,
+ ) -> None:
+ super().__init__()
+ self.patch_size = patch_size
+ self.temporal_patch_size = temporal_patch_size
+ self.hidden_size = hidden_size
+
+ kernel_size = (temporal_patch_size, patch_size, patch_size)
+ self.proj = nn.Conv3d(in_channels,
+ hidden_size,
+ kernel_size=kernel_size,
+ stride=kernel_size,
+ bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
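+        # Each row of x packs one flattened 3D patch of size
+        # C * temporal_patch_size * patch_size**2; unflatten it for the
+        # Conv3d projection, then flatten back to (L, hidden_size).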
+ L, C = x.shape
+ x = x.view(L, -1, self.temporal_patch_size, self.patch_size,
+ self.patch_size)
+ x = self.proj(x).view(L, self.hidden_size)
+ return x
+
+
+class Qwen2_5_VisionPatchMerger(nn.Module):
+
+ def __init__(
+ self,
+ d_model: int,
+ context_dim: int,
+ norm_layer: Optional[Callable[[int], nn.Module]] = None,
+ spatial_merge_size: int = 2,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+ self.hidden_size = context_dim * (spatial_merge_size**2)
+ if norm_layer is None:
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ self.ln_q = norm_layer(context_dim)
+ self.mlp = nn.ModuleList([
+ ColumnParallelLinear(self.hidden_size,
+ self.hidden_size,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp.0"),
+ nn.GELU(),
+ RowParallelLinear(self.hidden_size,
+ d_model,
+ bias=True,
+ quant_config=quant_config,
+ prefix=f"{prefix}.mlp.2"),
+ ])
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
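+        # Group spatial_merge_size**2 neighboring patch features into one
+        # vector and project it to the LLM hidden size with a two-layer MLP.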
+ x = self.ln_q(x)
+ x = x.view(-1, self.hidden_size)
+
+ mlp_fc1, mlp_act, mlp_fc2 = self.mlp
+ x_parallel, _ = mlp_fc1(x)
+ x_parallel = mlp_act(x_parallel)
+ out, _ = mlp_fc2(x_parallel)
+ return out
+
+
+class Qwen2_5_VisionRotaryEmbedding(nn.Module):
+
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
+ super().__init__()
+ self.dim = dim
+ self.theta = theta
+ inv_freq = 1.0 / (theta
+ **(torch.arange(0, dim, 2, dtype=torch.float) / dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self._seq_len_cached = 0
+ self._freqs_cached = None
+
+ def update_freqs_cache(self, seqlen: int) -> None:
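+        # Grow the cached frequency table geometrically (doubling the
+        # requested length) so it is rebuilt only occasionally as longer
+        # sequences arrive.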
+ if seqlen > self._seq_len_cached:
+ seqlen *= 2
+ self._seq_len_cached = seqlen
+ self.inv_freq = 1.0 / (self.theta**(torch.arange(
+ 0, self.dim, 2, dtype=torch.float, device=self.inv_freq.device)
+ / self.dim))
+ seq = torch.arange(seqlen,
+ device=self.inv_freq.device,
+ dtype=self.inv_freq.dtype)
+ freqs = torch.outer(seq, self.inv_freq)
+ self._freqs_cached = freqs
+
+ def forward(self, seqlen: int) -> torch.Tensor:
+ self.update_freqs_cache(seqlen)
+ return self._freqs_cached[:seqlen]
+
+
+class Qwen2_5_VisionTransformer(nn.Module):
+
+ def __init__(
+ self,
+ vision_config: Qwen2_5_VLVisionConfig,
+ norm_eps: float = 1e-6,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ) -> None:
+ super().__init__()
+
+ patch_size = vision_config.patch_size
+ temporal_patch_size = vision_config.temporal_patch_size
+ in_channels = vision_config.in_channels
+ depth = vision_config.depth
+ self.hidden_size = vision_config.hidden_size
+ self.num_heads = vision_config.num_heads
+
+ # args for get_window_index
+ self.window_size = vision_config.window_size
+ self.patch_size = vision_config.patch_size
+ self.spatial_merge_size = vision_config.spatial_merge_size
+ self.fullatt_block_indexes = vision_config.fullatt_block_indexes
+ self.spatial_merge_unit = self.spatial_merge_size**2
+
+ self.patch_embed = Qwen2_5_VisionPatchEmbed(
+ patch_size=patch_size,
+ temporal_patch_size=temporal_patch_size,
+ in_channels=in_channels,
+ hidden_size=self.hidden_size,
+ )
+
+ # NOTE: We use torch native RMSNorm here for precision purposes.
+ norm_layer = partial(Qwen2RMSNorm, eps=norm_eps)
+ head_dim = self.hidden_size // self.num_heads
+ self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
+
+ self.blocks = nn.ModuleList([
+ Qwen2_5_VisionBlock(
+ dim=self.hidden_size,
+ num_heads=self.num_heads,
+ mlp_hidden_dim=vision_config.intermediate_size,
+ act_fn=_ACTIVATION_REGISTRY[vision_config.hidden_act],
+ norm_layer=norm_layer,
+ quant_config=quant_config,
+ prefix=f"{prefix}.blocks.{layer_idx}")
+ for layer_idx in range(depth)
+ ])
+ self.merger = Qwen2_5_VisionPatchMerger(
+ d_model=vision_config.out_hidden_size,
+ context_dim=self.hidden_size,
+ norm_layer=norm_layer,
+ spatial_merge_size=self.spatial_merge_size,
+ quant_config=quant_config,
+ prefix=f"{prefix}.merger",
+ )
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return self.patch_embed.proj.weight.dtype
+
+ @property
+ def device(self) -> torch.device:
+ return self.patch_embed.proj.weight.device
+
+ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
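+        # Build per-patch (h, w) position ids, ordered so that patches merged
+        # together (spatial_merge_size x spatial_merge_size blocks) are
+        # adjacent, then gather their rotary embeddings.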
+ pos_ids = []
+ for t, h, w in grid_thw:
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+ hpos_ids = hpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ ).permute(0, 2, 1, 3).flatten()
+ wpos_ids = wpos_ids.reshape(
+ h // self.spatial_merge_size,
+ self.spatial_merge_size,
+ w // self.spatial_merge_size,
+ self.spatial_merge_size,
+ ).permute(0, 2, 1, 3).flatten()
+ pos_ids.append(
+ torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+ pos_ids = torch.cat(pos_ids, dim=0)
+ max_grid_size = grid_thw[:, 1:].max()
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+ return rotary_pos_emb
+
+ def get_window_index(self, grid_thw):
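+        # Reorder patch indices into fixed-size local windows (padding ragged
+        # edges with -100 and dropping the padding afterwards), and return the
+        # cumulative per-window sequence lengths used for windowed attention.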
+ window_index: list = []
+ cu_window_seqlens: list = [0]
+ window_index_id = 0
+ vit_merger_window_size = (self.window_size //
+ self.spatial_merge_size // self.patch_size)
+
+ for grid_t, grid_h, grid_w in grid_thw:
+ llm_grid_h = grid_h // self.spatial_merge_size
+ llm_grid_w = grid_w // self.spatial_merge_size
+ index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
+ grid_t, llm_grid_h, llm_grid_w)
+ pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
+ pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
+ num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
+ num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
+ index_padded = F.pad(index, (0, pad_w, 0, pad_h), 'constant', -100)
+ index_padded = index_padded.reshape(grid_t, num_windows_h,
+ vit_merger_window_size,
+ num_windows_w,
+ vit_merger_window_size)
+ index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
+ grid_t, num_windows_h * num_windows_w, vit_merger_window_size,
+ vit_merger_window_size)
+ seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
+ index_padded = index_padded.reshape(-1)
+ index_new = index_padded[index_padded != -100]
+ window_index.append(index_new + window_index_id)
+ cu_seqlens_tmp = seqlens.cumsum(
+ 0) * self.spatial_merge_unit + cu_window_seqlens[-1]
+ cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
+ window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
+ window_index = torch.cat(window_index, dim=0)
+ return window_index, cu_window_seqlens
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ grid_thw: torch.Tensor,
+ ) -> torch.Tensor:
+ # patchify
+ hidden_states = x.to(device=self.device, dtype=self.dtype)
+ hidden_states = self.patch_embed(hidden_states)
+
+ # compute position embedding
+ rotary_pos_emb = self.rot_pos_emb(grid_thw)
+
+ # windows attention
+ window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+ cu_window_seqlens = torch.tensor(
+ cu_window_seqlens,
+ device=hidden_states.device,
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32)
+ cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+ seq_len, _ = hidden_states.size()
+ hidden_states = hidden_states.reshape(
+ seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ hidden_states = hidden_states[window_index, :, :]
+ hidden_states = hidden_states.reshape(seq_len, -1)
+ rotary_pos_emb = rotary_pos_emb.reshape(
+ seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+ rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+ rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+ # compute cu_seqlens
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
+ grid_thw[:, 0]).cumsum(
+ dim=0, dtype=torch.int32)
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
+
+ # transformers
+ hidden_states = hidden_states.unsqueeze(1)
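+        # Blocks listed in fullatt_block_indexes attend over the full
+        # sequence; all other blocks use windowed attention via
+        # cu_window_seqlens.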
+ for layer_num, blk in enumerate(self.blocks):
+ if layer_num in self.fullatt_block_indexes:
+ cu_seqlens_now = cu_seqlens
+ else:
+ cu_seqlens_now = cu_window_seqlens
+ hidden_states = blk(hidden_states,
+ cu_seqlens=cu_seqlens_now,
+ rotary_pos_emb=rotary_pos_emb)
+
+ # adapter
+ hidden_states = self.merger(hidden_states)
+ reverse_indices = torch.argsort(window_index)
+ hidden_states = hidden_states[reverse_indices, :]
+ return hidden_states
+
+ def load_weights(self, weights: Iterable[Tuple[str,
+ torch.Tensor]]) -> Set[str]:
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ]
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ loaded_params: Set[str] = set()
+
+ for name, loaded_weight in weights:
+ for (param_name, weight_name, shard_id) in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
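+                # HF stores the fused qkv weight as [q; k; v] along dim 0;
+                # reorder it so each head's q, k and v are contiguous,
+                # matching the per-head split in the attention forward pass.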
+ if name.endswith("qkv.weight"):
+ visual_num_heads = self.num_heads
+ visual_embed_dim = self.hidden_size
+ head_size = visual_embed_dim // visual_num_heads
+ loaded_weight = loaded_weight.view(3, visual_num_heads,
+ head_size,
+ visual_embed_dim)
+ loaded_weight = loaded_weight.transpose(0, 1)
+ loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
+ elif name.endswith("qkv.bias"):
+ visual_num_heads = self.num_heads
+ visual_embed_dim = self.hidden_size
+ head_size = visual_embed_dim // visual_num_heads
+ loaded_weight = loaded_weight.view(3, visual_num_heads,
+ head_size)
+ loaded_weight = loaded_weight.transpose(0, 1)
+ loaded_weight = loaded_weight.reshape(-1)
+
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader",
+ default_weight_loader)
+ weight_loader(param, loaded_weight)
+ loaded_params.add(name)
+ return loaded_params
+
+
+class Qwen2_5_VLProcessingInfo(Qwen2VLProcessingInfo):
+
+ def get_hf_config(self):
+ return self.ctx.get_hf_config(Qwen2_5_VLConfig)
+
+ def get_hf_processor(
+ self,
+ *,
+ min_pixels: Optional[int] = None,
+ max_pixels: Optional[int] = None,
+ fps: Optional[float] = 2.0,
+ ) -> Qwen2_5_VLProcessor:
+ hf_processor = self.ctx.get_hf_processor(Qwen2_5_VLProcessor)
+ image_processor = hf_processor.image_processor # type: ignore
+ assert isinstance(image_processor, Qwen2_5_VLImageProcessor)
+
+ if min_pixels:
+ image_processor.min_pixels = min_pixels
+ if max_pixels:
+ image_processor.max_pixels = max_pixels
+ if max_pixels or min_pixels:
+ image_processor.size = {
+ "min_pixels": image_processor.min_pixels,
+ "max_pixels": image_processor.max_pixels,
+ }
+
+ return hf_processor
+
+ def get_image_processor(
+ self,
+ *,
+ min_pixels: Optional[int] = None,
+ max_pixels: Optional[int] = None,
+ fps: Optional[float] = 2.0,
+ ) -> Qwen2_5_VLImageProcessor:
+ hf_processor = self.get_hf_processor(
+ min_pixels=min_pixels,
+ max_pixels=max_pixels,
+ fps=fps,
+ )
+ image_processor = hf_processor.image_processor # type: ignore
+ assert isinstance(image_processor, Qwen2_5_VLImageProcessor)
+ return image_processor
+
+
+class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
+
+ def _get_mm_fields_config(
+ self,
+ hf_inputs: BatchFeature,
+ hf_processor_mm_kwargs: Mapping[str, object],
+ ) -> Mapping[str, MultiModalFieldConfig]:
+ return dict(
+ **super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs),
+ second_per_grid_ts=MultiModalFieldConfig.batched("video"),
+ )
+
+
+@MULTIMODAL_REGISTRY.register_processor(
+ Qwen2_5_VLMultiModalProcessor,
+ info=Qwen2_5_VLProcessingInfo,
+ dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
+class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
+ SupportsLoRA, SupportsPP):
+ packed_modules_mapping = {
+ "qkv_proj": [
+ "q_proj",
+ "k_proj",
+ "v_proj",
+ ]
+ }
+
+ # LoRA specific attributes, TODO: double check
+ supported_lora_modules = [
+ "qkv_proj",
+ "o_proj",
+ "gate_up_proj",
+ "down_proj",
+ "gate_proj"
+ "up_proj",
+ # vision tower
+ "qkv",
+ "attn.proj", # Distinguish patch_embed.proj
+ "fc1",
+ "fc2",
+ # projector
+ "mlp.0",
+ "mlp.2"
+ ]
+ embedding_modules = {}
+ embedding_padding_modules = []
+
+ # To ensure correct weight loading and mapping.
+ hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
+ "lm_head.": "language_model.lm_head.",
+ "model.": "language_model.model.",
+ })
+
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+ super().__init__()
+ config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
+ quant_config = vllm_config.quant_config
+ multimodal_config = vllm_config.model_config.multimodal_config
+
+ self.config = config
+ self.multimodal_config = multimodal_config
+
+ self.visual = Qwen2_5_VisionTransformer(
+ config.vision_config,
+ norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+ quant_config=self._maybe_ignore_quant_config(quant_config),
+ prefix=maybe_prefix(prefix, "visual"),
+ )
+
+ self.language_model = init_vllm_registered_model(
+ vllm_config=vllm_config,
+ prefix=maybe_prefix(prefix, "language_model"),
+ architectures=["Qwen2ForCausalLM"],
+ )
+
+ self.make_empty_intermediate_tensors = (
+ self.language_model.make_empty_intermediate_tensors)
+
+ @cached_property
+ def sampler(self):
+ if hasattr(self.language_model, "sampler"):
+ return self.language_model.sampler
+
+ return get_sampler()
+
+ def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        # GPTQ configs do not have a list of ignored modules; however,
+        # AutoGPTQ seems to avoid vision encoder sections for some models.
+ if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+ return None
+ return quant_config
+
+ def _validate_and_reshape_mm_tensor(self, mm_input: object,
+ name: str) -> torch.Tensor:
+ if not isinstance(mm_input, (torch.Tensor, list)):
+ raise ValueError(f"Incorrect type of {name}. "
+ f"Got type: {type(mm_input)}")
+ if isinstance(mm_input, torch.Tensor):
+ if mm_input.ndim == 2:
+ return mm_input
+ if mm_input.ndim != 3:
+ raise ValueError(f"{name} should be 2D or batched 3D tensor. "
+ f"Got ndim: {mm_input.ndim} "
+ f"(shape={mm_input.shape})")
+ return torch.concat(list(mm_input))
+ else:
+ return torch.concat(mm_input)
+
+ def _parse_and_validate_image_input(
+ self, **kwargs: object) -> Optional[Qwen2_5_VLImageInputs]:
+ pixel_values = kwargs.pop("pixel_values", None)
+ image_embeds = kwargs.pop("image_embeds", None)
+ image_grid_thw = kwargs.pop("image_grid_thw", None)
+
+ if pixel_values is None and image_embeds is None:
+ return None
+
+ if pixel_values is not None:
+ pixel_values = self._validate_and_reshape_mm_tensor(
+ pixel_values, "image pixel values")
+ image_grid_thw = self._validate_and_reshape_mm_tensor(
+ image_grid_thw, "image grid_thw")
+
+ if not isinstance(pixel_values, (torch.Tensor, list)):
+ raise ValueError("Incorrect type of image pixel values. "
+ f"Got type: {type(pixel_values)}")
+
+ return Qwen2_5_VLImagePixelInputs(type="pixel_values",
+ pixel_values=pixel_values,
+ image_grid_thw=image_grid_thw)
+
+ if image_embeds is not None:
+ image_embeds = self._validate_and_reshape_mm_tensor(
+ image_embeds, "image embeds")
+ image_grid_thw = self._validate_and_reshape_mm_tensor(
+ image_grid_thw, "image grid_thw")
+
+ if not isinstance(image_embeds, torch.Tensor):
+ raise ValueError("Incorrect type of image embeddings. "
+ f"Got type: {type(image_embeds)}")
+ return Qwen2_5_VLImageEmbeddingInputs(
+ type="image_embeds",
+ image_embeds=image_embeds,
+ image_grid_thw=image_grid_thw)
+
+ def _parse_and_validate_video_input(
+ self, **kwargs: object) -> Optional[Qwen2_5_VLVideoInputs]:
+ pixel_values_videos = kwargs.pop("pixel_values_videos", None)
+ video_embeds = kwargs.pop("video_embeds", None)
+ video_grid_thw = kwargs.pop("video_grid_thw", None)
+ second_per_grid_ts = kwargs.pop("second_per_grid_ts", None)
+
+ if pixel_values_videos is None and video_embeds is None:
+ return None
+
+ if pixel_values_videos is not None:
+ pixel_values_videos = self._validate_and_reshape_mm_tensor(
+ pixel_values_videos, "video pixel values")
+ video_grid_thw = self._validate_and_reshape_mm_tensor(
+ video_grid_thw, "video grid_thw")
+
+ return Qwen2_5_VLVideoPixelInputs(
+ type="pixel_values_videos",
+ pixel_values_videos=pixel_values_videos,
+ video_grid_thw=video_grid_thw,
+ second_per_grid_ts=second_per_grid_ts,
+ )
+
+ if video_embeds is not None:
+ video_embeds = self._validate_and_reshape_mm_tensor(
+ video_embeds, "video embeds")
+ video_grid_thw = self._validate_and_reshape_mm_tensor(
+ video_grid_thw, "video grid_thw")
+
+ if not isinstance(video_embeds, torch.Tensor):
+ raise ValueError("Incorrect type of video embeddings. "
+ f"Got type: {type(video_embeds)}")
+ return Qwen2_5_VLVideoEmbeddingInputs(
+ type="video_embeds",
+ video_embeds=video_embeds,
+ video_grid_thw=video_grid_thw)
+
+ def _process_image_input(
+ self,
+ image_input: Qwen2_5_VLImageInputs) -> tuple[torch.Tensor, ...]:
+
+ grid_thw = image_input["image_grid_thw"]
+ assert grid_thw.ndim == 2
+
+ if image_input["type"] == "image_embeds":
+ image_embeds = image_input["image_embeds"].type(self.visual.dtype)
+ else:
+ pixel_values = image_input["pixel_values"].type(self.visual.dtype)
+ image_embeds = self.visual(pixel_values, grid_thw=grid_thw)
+
+ # Split concatenated embeddings for each image item.
+ merge_size = self.visual.spatial_merge_size
+ sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+ return image_embeds.split(sizes.tolist())
+
+ def _process_video_input(
+ self,
+ video_input: Qwen2_5_VLVideoInputs) -> tuple[torch.Tensor, ...]:
+
+ grid_thw = video_input["video_grid_thw"]
+ assert grid_thw.ndim == 2
+
+ if video_input["type"] == "video_embeds":
+ video_embeds = video_input["video_embeds"].type(self.visual.dtype)
+ else:
+ pixel_values_videos = video_input["pixel_values_videos"].type(
+ self.visual.dtype)
+ video_embeds = self.visual(pixel_values_videos, grid_thw=grid_thw)
+
+ # Split concatenated embeddings for each video item.
+ merge_size = self.visual.spatial_merge_size
+ sizes = grid_thw.prod(-1) // merge_size // merge_size
+
+ return video_embeds.split(sizes.tolist())
+
+ def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+ modalities = {}
+
+ # Preserve the order of modalities if there are multiple of them
+ # from the order of kwargs.
+ for input_key in kwargs:
+ if input_key in ("pixel_values",
+ "image_embeds") and "images" not in modalities:
+ modalities["images"] = self._parse_and_validate_image_input(
+ **kwargs)
+ if input_key in ("pixel_values_videos",
+ "video_embeds") and "videos" not in modalities:
+ modalities["videos"] = self._parse_and_validate_video_input(
+ **kwargs)
+ return modalities
+
+ def get_multimodal_embeddings(
+ self, **kwargs) -> Optional[tuple[torch.Tensor, ...]]:
+
+ modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
+ if not modalities:
+ return None
+
+        # The resulting multimodal_embeddings is a tuple of tensors, with each
+        # tensor corresponding to a multimodal data item (image or video).
+ multimodal_embeddings: tuple[torch.Tensor, ...] = ()
+
+ # NOTE: It is important to iterate over the keys in this dictionary
+ # to preserve the order of the modalities.
+ for modality in modalities:
+ if modality == "images":
+ image_input = modalities["images"]
+ vision_embeddings = self._process_image_input(image_input)
+ multimodal_embeddings += vision_embeddings
+ if modality == "videos":
+ video_input = modalities["videos"]
+ video_embeddings = self._process_video_input(video_input)
+ multimodal_embeddings += video_embeddings
+ return multimodal_embeddings
+
+ def get_input_embeddings(
+ self,
+ input_ids: torch.Tensor,
+ multimodal_embeddings: Optional[tuple[torch.Tensor, ...]] = None,
+ ) -> torch.Tensor:
+ inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+ if multimodal_embeddings is not None:
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids, inputs_embeds, multimodal_embeddings,
+ [self.config.image_token_id, self.config.video_token_id])
+ return inputs_embeds
+
+ def get_input_embeddings_v0(
+ self,
+ input_ids: torch.Tensor,
+ image_input: Optional[tuple[torch.Tensor, ...]] = None,
+ video_input: Optional[tuple[torch.Tensor, ...]] = None,
+ ) -> torch.Tensor:
+
+ inputs_embeds = self.get_input_embeddings(input_ids)
+ if image_input is not None:
+ image_embeds = self._process_image_input(image_input)
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ image_embeds,
+ placeholder_token_id=self.config.image_token_id,
+ )
+
+ if video_input is not None:
+ video_embeds = self._process_video_input(video_input)
+ inputs_embeds = merge_multimodal_embeddings(
+ input_ids,
+ inputs_embeds,
+ video_embeds,
+ placeholder_token_id=self.config.video_token_id,
+ )
+ return inputs_embeds
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ kv_caches: List[torch.Tensor],
+ attn_metadata: AttentionMetadata,
+ intermediate_tensors: Optional[IntermediateTensors] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ **kwargs: object,
+ ) -> Union[torch.Tensor, IntermediateTensors]:
+ """Run forward pass for Qwen2.5-VL.
+
+ Args:
+ input_ids: Flattened (concatenated) input_ids corresponding to a
+ batch.
+ positions: Flattened (concatenated) position ids corresponding to a
+ batch.
+ **NOTE**: If mrope is enabled (default setting for Qwen2.5-VL
+                open-source models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,)`.
+ pixel_values: Pixel values to be fed to a model.
+ `None` if no images are passed.
+ image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
+ `None` if no images are passed.
+ pixel_values_videos: Pixel values of videos to be fed to a model.
+ `None` if no videos are passed.
+ video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM.
+ `None` if no videos are passed.
+ second_per_grid_ts: Tensor `(num_videos)` of video time interval (
+ in seconds) for each grid along the temporal dimension in the
+ 3D position IDs. `None` if no videos are passed.
+ """
+
+ if intermediate_tensors is not None:
+ inputs_embeds = None
+
+        # NOTE: In v1, inputs_embeds is always generated by the model runner
+        # from `get_multimodal_embeddings` and `get_input_embeddings`; this
+        # condition is only for v0 compatibility.
+ elif inputs_embeds is None:
+ image_input = self._parse_and_validate_image_input(**kwargs)
+ video_input = self._parse_and_validate_video_input(**kwargs)
+
+ if image_input is None and video_input is None:
+ inputs_embeds = None
+ else:
+ if uses_mrope(self.config):
+ assert positions.ndim == 2 and positions.size(0) == 3, (
+ "multimodal section rotary embedding requires "
+ f"(3, seq_len) positions, but got {positions.size()}")
+ inputs_embeds = self.get_input_embeddings_v0(
+ input_ids,
+ image_input=image_input,
+ video_input=video_input)
+ input_ids = None
+
+ hidden_states = self.language_model.model(
+ input_ids=input_ids,
+ positions=positions,
+ kv_caches=kv_caches,
+ attn_metadata=attn_metadata,
+ intermediate_tensors=intermediate_tensors,
+ inputs_embeds=inputs_embeds,
+ )
+ return hidden_states
+
+ def compute_logits(
+ self,
+ hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[torch.Tensor]:
+ return self.language_model.compute_logits(hidden_states,
+ sampling_metadata)
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ return self.language_model.sample(logits, sampling_metadata)
+
+ def load_weights(self, weights: Iterable[Tuple[str,
+ torch.Tensor]]) -> Set[str]:
+
+ loader = AutoWeightsLoader(self)
+ return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
+
+ def get_mm_mapping(self) -> MultiModelKeys:
+ """
+ Get the module prefix in multimodal models
+ """
+ return MultiModelKeys.from_string_field(
+ language_model="language_model",
+ connector="visual.",
+ tower_model="visual.merger.")
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 2b2638cf68fc7..34ae7b8c94697 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -650,8 +650,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
return loaded_params
-class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
- dict[str, torch.Tensor]]):
+class Qwen2VLEmbeddingItems(ModalityDataItems[dict[str, torch.Tensor],
+ dict[str, torch.Tensor]]):
def __init__(self, data: dict, modality: str) -> None:
super().__init__(data, modality)
@@ -683,26 +683,26 @@ def get_passthrough_data(self) -> Mapping[str, object]:
return self.data
-class Qwen2ImageEmbeddingItems(Qwen2EmbeddingItems):
+class Qwen2VLImageEmbeddingItems(Qwen2VLEmbeddingItems):
def __init__(self, data: dict) -> None:
super().__init__(data, "image")
-class Qwen2VideoEmbeddingItems(Qwen2EmbeddingItems):
+class Qwen2VLVideoEmbeddingItems(Qwen2VLEmbeddingItems):
def __init__(self, data: dict) -> None:
super().__init__(data, "video")
-class Qwen2MultiModalDataParser(MultiModalDataParser):
+class Qwen2VLMultiModalDataParser(MultiModalDataParser):
def _parse_image_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
- return Qwen2EmbeddingItems(data, modality="image")
+ return Qwen2VLEmbeddingItems(data, modality="image")
return super()._parse_image_data(data)
@@ -711,7 +711,7 @@ def _parse_video_data(
data: Union[dict[str, torch.Tensor], ModalityData[VideoItem]],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
- return Qwen2EmbeddingItems(data, modality="video")
+ return Qwen2VLEmbeddingItems(data, modality="video")
return super()._parse_video_data(data)
@@ -948,7 +948,7 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor[Qwen2VLProcessingInfo]
):
def _get_data_parser(self) -> MultiModalDataParser:
- return Qwen2MultiModalDataParser()
+ return Qwen2VLMultiModalDataParser()
def _get_prompt_replacements(
self,
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 962f95f10fc51..b6708f77d8aff 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -172,6 +172,7 @@
"PixtralForConditionalGeneration": ("pixtral", "PixtralForConditionalGeneration"), # noqa: E501
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration"), # noqa: E501
+ "Qwen2_5_VLForConditionalGeneration": ("qwen2_5_vl", "Qwen2_5_VLForConditionalGeneration"), # noqa: E501
"Qwen2AudioForConditionalGeneration": ("qwen2_audio", "Qwen2AudioForConditionalGeneration"), # noqa: E501
"UltravoxModel": ("ultravox", "UltravoxModel"),
# [Encoder-decoder]
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 7841fac1df34b..ec6d04cd49752 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -285,6 +285,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
if self.model_config.uses_mrope:
image_grid_thw = []
video_grid_thw = []
+ second_per_grid_ts = []
for mm_input in self.requests[req_id].mm_inputs:
if mm_input.get("image_grid_thw") is not None:
image_grid_thw.extend(
@@ -292,6 +293,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
if mm_input.get("video_grid_thw") is not None:
video_grid_thw.extend(
mm_input["video_grid_thw"].tolist())
+ if mm_input.get("second_per_grid_ts") is not None:
+ second_per_grid_ts.extend(
+ mm_input["second_per_grid_ts"])
hf_config = self.model_config.hf_config
@@ -299,14 +303,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
self.requests[req_id].mrope_position_delta = \
MRotaryEmbedding.get_input_positions_tensor(
self.requests[req_id].prompt_token_ids,
+ hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
- image_token_id=hf_config.image_token_id,
- video_token_id=hf_config.video_token_id,
- vision_start_token_id=hf_config.vision_start_token_id,
- vision_end_token_id=hf_config.vision_end_token_id,
- spatial_merge_size=hf_config.vision_config.
- spatial_merge_size,
+ second_per_grid_ts=second_per_grid_ts,
)
req_ids_to_add.append(req_id)
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 1c3feece95a5a..9400893105d73 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -386,20 +386,17 @@ def _compute_multi_modal_input(self,
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw'.")
+ second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
hf_config = self.runner.model_config.hf_config
token_ids = seq_data.get_token_ids()
mrope_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
+ hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
- image_token_id=hf_config.image_token_id,
- video_token_id=hf_config.video_token_id,
- vision_start_token_id=hf_config.vision_start_token_id,
- vision_end_token_id=hf_config.vision_end_token_id,
- spatial_merge_size=hf_config.vision_config.
- spatial_merge_size,
+ second_per_grid_ts=second_per_grid_ts,
context_len=computed_len,
)
seq_data.mrope_position_delta = mrope_position_delta
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 0bbba55b3b3f8..12baecde6e42c 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -702,6 +702,7 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
"mrope embedding type requires multi-modal input mapper "
"returns 'image_grid_thw' or 'video_grid_thw'.")
+ second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
hf_config = self.runner.model_config.hf_config
inter_data.mrope_input_positions = [None] * inter_data.n_seqs
@@ -713,14 +714,10 @@ def _compute_multi_modal_input(self, inter_data: InterDataForSeqGroup,
mrope_input_positions, mrope_position_delta = \
MRotaryEmbedding.get_input_positions(
token_ids,
+ hf_config=hf_config,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
- image_token_id=hf_config.image_token_id,
- video_token_id=hf_config.video_token_id,
- vision_start_token_id=hf_config.vision_start_token_id,
- vision_end_token_id=hf_config.vision_end_token_id,
- spatial_merge_size=hf_config.vision_config.
- spatial_merge_size,
+ second_per_grid_ts=second_per_grid_ts,
context_len=inter_data.context_lens[seq_idx],
seq_len=inter_data.seq_lens[seq_idx],
)