From 32c9eff2fff8ee91a60c9410c69042dc4c1cc5c8 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 6 Jan 2025 23:22:25 +0800 Subject: [PATCH 001/114] [Bugfix][V1] Fix molmo text-only inputs (#11676) Signed-off-by: Jee Jee Li --- .../vision_language/test_models.py | 10 ++ .../vision_language/vlm_utils/model_utils.py | 99 ++++++++++++++++++- vllm/model_executor/models/molmo.py | 56 ++++------- 3 files changed, 123 insertions(+), 42 deletions(-) diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index dc0b683c1f1cb..146685738a1d0 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -341,6 +341,16 @@ ), hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, ), + "molmo": VLMTestInfo( + models=["allenai/Molmo-7B-D-0924"], + test_type=(VLMTestType.IMAGE), + prompt_formatter=lambda img_prompt:"User: " + img_prompt + " Assistant:", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + image_size_factors=[(),(1.0, 1.0, 1.0)], + patch_hf_runner=model_utils.mlomo_patch_hf_runner, + postprocess_inputs=model_utils.molmo_post_processor, + ), # Tests for phi3v currently live in another file because of a bug in # transformers. Once this issue is fixed, we can enable them here instead. # https://github.com/huggingface/transformers/issues/34307 diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 3eca8fb9dcb1a..6c7a753af787e 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -5,17 +5,20 @@ import re import types from pathlib import PosixPath -from typing import Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from PIL.Image import Image -from transformers import AutoConfig, AutoTokenizer, BatchEncoding +from transformers import (AutoConfig, AutoTokenizer, BatchEncoding, + GenerationConfig) from vllm.sequence import SampleLogprobs from vllm.transformers_utils.tokenizer import patch_padding_side from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from .....conftest import HfRunner, ImageAsset, _ImageAssets +from .....conftest import (HfRunner, ImageAsset, PromptAudioInput, + PromptImageInput, PromptVideoInput, _ImageAssets) +from ....utils import TokensTextLogprobs from .types import RunnerOutput @@ -222,6 +225,11 @@ def wrap_inputs_post_processor(hf_inputs: BatchEncoding, dtype: str): return {"model_inputs": hf_inputs} +def molmo_post_processor(hf_inputs: BatchEncoding, dtype: str): + hf_inputs = cast_dtype_post_processor("images")(hf_inputs, dtype) + return {k: v.unsqueeze(0) for k, v in hf_inputs.items()} + + ####### Prompt path encoders for models that need models on disk def qwen_prompt_path_encoder( tmp_path: PosixPath, prompt: str, assets: Union[List[ImageAsset], @@ -451,3 +459,88 @@ def _generate(self, *args, **kwargs): hf_model.model.generate = types.MethodType(_generate, hf_model.model) return hf_model + + +def _generate_greedy_logprobs_limit( + self, + prompts: List[str], + max_tokens: int, + num_logprobs: int, + images: Optional[PromptImageInput] = None, + audios: Optional[PromptAudioInput] = None, + videos: Optional[PromptVideoInput] = None, + **kwargs: Any, +) -> List[TokensTextLogprobs]: + all_inputs = self.get_inputs(prompts, + images=images, + videos=videos, + audios=audios) + + # Process in batches for inference. + if len(all_inputs): + input_ids_lst = [] + images_lst = [] + images_input_idx_lst = [] + imges_masks_lst = [] + for inputs in all_inputs: + input_ids_lst.append(inputs["input_ids"]) + images_lst.append(inputs["images"]) + images_input_idx_lst.append(inputs["image_input_idx"]) + imges_masks_lst.append(inputs["image_masks"]) + batch_inputs = {} + batch_inputs['input_ids'] = torch.cat(input_ids_lst, dim=0) + batch_inputs['images'] = torch.cat(images_lst, dim=0) + batch_inputs['image_input_idx'] = torch.cat(images_input_idx_lst, + dim=0) + batch_inputs['image_masks'] = torch.cat(imges_masks_lst, dim=0) + + outputs = self.model.generate_from_batch( + batch=self.wrap_device(batch_inputs, + device=self.model.device.type), + generation_config=GenerationConfig( + max_new_tokens=max_tokens, + stop_strings="<|endoftext|>", + do_sample=False, + ), + tokenizer=self.tokenizer, + output_hidden_states=True, + return_dict_in_generate=True, + ) + + all_logprobs: List[List[Dict[int, float]]] = [] + all_output_ids: List[List[int]] = [] + all_output_strs: List[str] = [] + + for index in range(len(all_inputs)): + ( + seq_logprobs_lst, + output_len, + ) = self._hidden_states_to_logprobs(outputs.hidden_states, + num_logprobs) + all_logprobs.append(seq_logprobs_lst) + seq_ids = outputs.sequences[index] + output_ids = seq_ids[-output_len:] + all_output_ids.append(output_ids.tolist()) + all_output_strs.append(self.tokenizer.decode(output_ids)) + outputs = zip(all_output_ids, all_output_strs, all_logprobs) + return [(output_ids, output_str, output_logprobs) + for output_ids, output_str, output_logprobs in outputs] + + +####### Molmo-specific HuggingFace runner patchers +def mlomo_patch_hf_runner(hf_model: HfRunner) -> HfRunner: + """Patches and returns an instance of the HfRunner to use for Molmo.""" + hf_processor = hf_model.processor + + def _processor(*args, **kwargs): + return hf_processor.process(*args, **kwargs) + + hf_model.processor = _processor + + setattr( # noqa: B010 + hf_model, + "generate_greedy_logprobs_limit", + types.MethodType(_generate_greedy_logprobs_limit, hf_model), + ) + + return hf_model diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index cc25be9f5b6a9..0e8287bb56b6b 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -1081,45 +1081,25 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): else: out = processor.process(None, image, tokens=inputs["prompt_token_ids"]) - image_processor = processor.image_processor - max_total_crops = 1 + image_processor.max_crops - if image is not None: - images, image_input_idx, image_masks = pad_images( - max_total_crops, - out["images"], - out["image_input_idx"], - out.get("image_masks"), - ) - else: - base_image_input_size = image_processor.base_image_input_size - image_patch_size = image_processor.image_patch_size - image_num_patch = ( - base_image_input_size[0] // image_patch_size, - base_image_input_size[1] // image_patch_size, - ) - n_pixels = image_patch_size * image_patch_size * 3 - n_patches = image_num_patch[0] * image_num_patch[1] - - image_length_w = image_processor.image_token_length_w - image_length_h = image_processor.image_token_length_h - tokens_per_image = image_length_w * image_length_h - images = torch.full( - (max_total_crops, n_patches, n_pixels), - -1, - dtype=torch.float32, - ) - image_input_idx = torch.full( - (max_total_crops, tokens_per_image), - -1, - dtype=torch.int32, + # If there is no image, return directly. + if image is None: + new_prompt_token_ids = out["input_ids"].tolist() + prompt = inputs.get("prompt") + if prompt is None: + prompt = tokenizer.decode(new_prompt_token_ids) + return token_inputs( + prompt_token_ids=new_prompt_token_ids, + prompt=prompt, ) - if image_processor.image_padding_mask: - image_masks = torch.full( - (max_total_crops, n_patches), - -1, - dtype=torch.float32, - ) + image_processor = processor.image_processor + max_total_crops = 1 + image_processor.max_crops + images, image_input_idx, image_masks = pad_images( + max_total_crops, + out["images"], + out["image_input_idx"], + out.get("image_masks"), + ) image_data = dict( images=images, image_input_idx=image_input_idx, @@ -1143,11 +1123,9 @@ def input_processor_for_molmo(ctx: InputContext, inputs: DecoderOnlyInputs): offset = i size += 1 image_data["image_start_end"] = (offset, offset + size) - prompt = inputs.get("prompt") if prompt is None: prompt = tokenizer.decode(new_prompt_token_ids) - return token_inputs( prompt_token_ids=new_prompt_token_ids, prompt=prompt, From e20c92bb618384ce8d0013e0c9ad273d0c23d65b Mon Sep 17 00:00:00 2001 From: Chen Zhang Date: Tue, 7 Jan 2025 00:11:28 +0800 Subject: [PATCH 002/114] [Kernel] Move attn_type to Attention.__init__() (#11690) Signed-off-by: Chen Zhang --- tests/kernels/test_encoder_decoder_attn.py | 100 ++++++++++---------- tests/kernels/utils.py | 12 ++- vllm/attention/backends/abstract.py | 2 +- vllm/attention/backends/blocksparse_attn.py | 14 +-- vllm/attention/backends/flash_attn.py | 4 +- vllm/attention/backends/flashinfer.py | 15 ++- vllm/attention/backends/hpu_attn.py | 13 +-- vllm/attention/backends/ipex_attn.py | 12 +-- vllm/attention/backends/pallas.py | 13 +-- vllm/attention/backends/rocm_flash_attn.py | 14 +-- vllm/attention/backends/torch_sdpa.py | 4 +- vllm/attention/backends/xformers.py | 6 +- vllm/attention/layer.py | 37 ++------ vllm/model_executor/models/bart.py | 44 +++------ vllm/model_executor/models/bert.py | 10 +- vllm/model_executor/models/mllama.py | 11 +-- vllm/model_executor/models/qwen2.py | 35 ++++--- vllm/v1/attention/backends/flash_attn.py | 14 +-- 18 files changed, 159 insertions(+), 201 deletions(-) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index d943b048b7934..614674375786e 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -13,8 +13,7 @@ import torch from tests.kernels.utils import * -from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, - AttentionType) +from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.selector import (_Backend, _cached_get_attn_backend, global_force_attn_backend_context_manager) @@ -64,6 +63,7 @@ class TestPoint(NamedTuple): max_dec_seq_len: int max_enc_seq_len: int num_blocks: int + attn_type: AttentionType class TestResources(NamedTuple): @@ -96,7 +96,6 @@ class TestResources(NamedTuple): ''' scale: float - attn_backend: AttentionBackend attn: Attention kv_cache: torch.Tensor @@ -129,16 +128,17 @@ class that Attention will automatically select when it is constructed. ''' scale = float(1.0 / (test_pt.head_size**0.5)) - attn_backend = make_backend(test_pt.backend_name) attn = Attention( test_pt.num_heads, test_pt.head_size, scale=scale, + prefix=f"{test_pt.attn_type}", + attn_type=test_pt.attn_type, ) if test_pt.num_blocks is None or test_pt.num_heads is None: # Caller does not require a KV cache return TestResources( - scale, attn_backend, attn, + scale, attn, torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE)) # Construct KV cache @@ -148,7 +148,7 @@ class that Attention will automatically select when it is constructed. test_pt.block_size, device=CUDA_DEVICE, backend=test_pt.backend_name) - return TestResources(scale, attn_backend, attn, kv_cache) + return TestResources(scale, attn, kv_cache) def _encoder_attn_setup( @@ -193,6 +193,7 @@ def _encoder_attn_setup( _, max_q_seq_len, _, + _, ) = test_pt scale = test_rsrcs.scale @@ -301,6 +302,7 @@ def _decoder_attn_setup( max_q_seq_len, _, _, + _, ) = test_pt scale = test_rsrcs.scale @@ -488,6 +490,7 @@ def _enc_dec_cross_attn_setup_reuses_query( max_decoder_seq_len, max_encoder_seq_len, _, + _, ) = test_pt scale = test_rsrcs.scale @@ -622,7 +625,6 @@ def _run_encoder_attention_test( & attn_metadata ''' assert attn_metadata.num_decode_tokens == 0 - attn_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None with set_forward_context(attn_metadata, vllm_config): @@ -635,14 +637,11 @@ def _run_encoder_attention_test( # is shaped as [num_tokens, hidden_size] and we can skip the reshape. reshaped_query = packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, - packed_qkv.key, - packed_qkv.value, - torch.tensor([], - dtype=torch.float32, - device=packed_qkv.query.device), - attn_metadata, - attn_type=attn_type) + return attn.forward( + reshaped_query, packed_qkv.key, packed_qkv.value, + torch.tensor([], + dtype=torch.float32, + device=packed_qkv.query.device), attn_metadata) def _run_decoder_self_attention_test( @@ -675,7 +674,6 @@ def _run_decoder_self_attention_test( * Attention.forward() applied to packed_{query,key,value}, kv_cache & attn_metadata ''' - attn_type = AttentionType.DECODER attn = test_rsrcs.attn kv_cache = test_rsrcs.kv_cache packed_qkv = decoder_test_params.packed_qkvo.packed_qkv @@ -690,12 +688,8 @@ def _run_decoder_self_attention_test( # is shaped as [num_tokens, hidden_size] and we can skip the reshape. reshaped_query = packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, - packed_qkv.key, - packed_qkv.value, - kv_cache, - attn_metadata, - attn_type=attn_type) + return attn.forward(reshaped_query, packed_qkv.key, packed_qkv.value, + kv_cache, attn_metadata) def _run_encoder_decoder_cross_attention_test( @@ -742,7 +736,6 @@ def _run_encoder_decoder_cross_attention_test( ''' assert decoder_test_params.packed_qkvo.packed_qkv is not None - attn_type = AttentionType.ENCODER_DECODER attn = test_rsrcs.attn kv_cache = test_rsrcs.kv_cache if cross_test_params is None: @@ -762,12 +755,8 @@ def _run_encoder_decoder_cross_attention_test( # is shaped as [num_tokens, hidden_size] and we can skip the reshape. reshaped_query = decoder_test_params.packed_qkvo.packed_qkv.query.view( -1, test_pt.num_heads * test_pt.head_size) - return attn.forward(reshaped_query, - key, - value, - kv_cache, - attn_metadata, - attn_type=attn_type) + return attn.forward(reshaped_query, key, value, kv_cache, + attn_metadata) @pytest.fixture(autouse=True) @@ -839,7 +828,7 @@ def test_encoder_only( # is not part of this test test_pt = TestPoint(num_heads, head_size, attn_backend.name, batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096) + max_enc_seq_len, 4096, AttentionType.ENCODER) # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init @@ -855,7 +844,7 @@ def test_encoder_only( # Shared prefill metadata structure prephase_attn_metadata: AttentionMetadata = make_test_metadata( - test_rsrcs.attn_backend, + attn_backend, True, None, decoder_test_params=None, @@ -961,20 +950,29 @@ def test_e2e_enc_dec_attn( # Note: KV cache size of 4096 is arbitrary & chosen intentionally # to be more than necessary, since exceeding the kv cache size # is not part of this test - test_pt = TestPoint(num_heads, head_size, attn_backend.name, - batch_size, block_size, max_dec_seq_len, - max_enc_seq_len, 4096) + enc_test_pt = TestPoint(num_heads, head_size, attn_backend.name, + batch_size, block_size, max_dec_seq_len, + max_enc_seq_len, 4096, AttentionType.ENCODER) + enc_dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name, + batch_size, block_size, max_dec_seq_len, + max_enc_seq_len, 4096, + AttentionType.ENCODER_DECODER) + dec_test_pt = TestPoint(num_heads, head_size, attn_backend.name, + batch_size, block_size, max_dec_seq_len, + max_enc_seq_len, 4096, AttentionType.DECODER) # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): - test_rsrcs = _make_test_resources(test_pt) + enc_test_rsrcs = _make_test_resources(enc_test_pt) + enc_dec_test_rsrcs = _make_test_resources(enc_dec_test_pt) + dec_test_rsrcs = _make_test_resources(dec_test_pt) # Construct encoder attention test params (only used # during prefill) - enc_test_params = _encoder_attn_setup(test_pt, test_rsrcs) + enc_test_params = _encoder_attn_setup(enc_test_pt, enc_test_rsrcs) # Construct Decoder self-attention prefill-phase & decode-phase # test params, including query/key/value tensors, decoder self-attention @@ -987,7 +985,7 @@ def test_e2e_enc_dec_attn( prephase_dec_test_params, decphase_dec_test_params, cross_block_base_addr, - ) = _decoder_attn_setup(test_pt, test_rsrcs) + ) = _decoder_attn_setup(dec_test_pt, dec_test_rsrcs) # Construct encoder/decoder cross-attention prefill-phase # & decode-phase test params, including key/value tensors, @@ -1000,14 +998,14 @@ def test_e2e_enc_dec_attn( dec_qkv, enc_test_params, prephase_dec_test_params, - test_pt, - test_rsrcs, + enc_dec_test_pt, + enc_dec_test_rsrcs, block_base_addr=cross_block_base_addr) # Shared prefill metadata structure assert prephase_dec_test_params.packed_qkvo.packed_qkv is not None prephase_attn_metadata: AttentionMetadata = make_test_metadata( - test_rsrcs.attn_backend, + attn_backend, True, prephase_dec_test_params.packed_qkvo.packed_qkv.q_seq_lens, decoder_test_params=prephase_dec_test_params, @@ -1017,10 +1015,10 @@ def test_e2e_enc_dec_attn( # PREFILL: encoder attention - enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, + enc_pckd_act_out = _run_encoder_attention_test(enc_test_rsrcs.attn, enc_test_params, prephase_attn_metadata, - test_pt=test_pt, + test_pt=enc_test_pt, vllm_config=vllm_config) # - Is encoder attention result correct? @@ -1030,10 +1028,10 @@ def test_e2e_enc_dec_attn( # PREFILL: decoder self-attention test prephase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, + dec_test_rsrcs, prephase_dec_test_params, prephase_attn_metadata, - test_pt=test_pt, + test_pt=dec_test_pt, vllm_config=vllm_config) # - Is prefill decoder self-attention correct? @@ -1044,11 +1042,11 @@ def test_e2e_enc_dec_attn( # PREFILL: encoder/decoder cross-attention test prephase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, + enc_dec_test_rsrcs, prephase_dec_test_params, prephase_cross_test_params, prephase_attn_metadata, - test_pt=test_pt, + test_pt=enc_dec_test_pt, vllm_config=vllm_config) # - Is prefill encoder/decoder cross-attention correct? @@ -1059,7 +1057,7 @@ def test_e2e_enc_dec_attn( # DECODE: build decode-phase attention metadata decphase_attn_metadata: AttentionMetadata = make_test_metadata( - test_rsrcs.attn_backend, + attn_backend, False, dec_qkv.q_seq_lens, decoder_test_params=decphase_dec_test_params, @@ -1070,10 +1068,10 @@ def test_e2e_enc_dec_attn( # DECODE: decoder self-attention test decphase_dec_pckd_act_out = _run_decoder_self_attention_test( - test_rsrcs, + dec_test_rsrcs, decphase_dec_test_params, decphase_attn_metadata, - test_pt=test_pt, + test_pt=dec_test_pt, vllm_config=vllm_config) # - Is decode-phase decoder self-attention correct? @@ -1084,11 +1082,11 @@ def test_e2e_enc_dec_attn( # DECODE: encoder/decoder cross-attention test decphase_cross_pckd_act_out = _run_encoder_decoder_cross_attention_test( - test_rsrcs, + enc_dec_test_rsrcs, decphase_dec_test_params, None, decphase_attn_metadata, - test_pt=test_pt, + test_pt=enc_dec_test_pt, vllm_config=vllm_config) # - Is decode-phase encoder/decoder cross-attention correct? diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index e7865fb2500ef..848eea7f54cab 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -13,6 +13,7 @@ from vllm.attention import AttentionBackend, AttentionMetadata, AttentionType from vllm.model_executor.layers.activation import SiluAndMul +from vllm.platforms.interface import _Backend from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) @@ -790,7 +791,7 @@ def make_block_tables_slot_mapping( def make_test_metadata( - attn_backend: AttentionBackend, + attn_backend: _Backend, is_prompt: bool, seq_lens: Optional[List[int]], decoder_test_params: Optional[PhaseTestParameters], @@ -815,7 +816,7 @@ def make_test_metadata( Arguments: - * attn_backend: Backend for sourcing attention kernels + * attn_backend_name: Backend for sourcing attention kernels * is_prompt: prefill if True, o/w decode * seq_lens: list of token counts for each sequence * decoder_test_params: decoder self-attention test params; @@ -882,6 +883,8 @@ def make_test_metadata( # (kv_mmap) cross_kv_mmap = cross_test_params.kv_mmap + attn_backend_obj = make_backend(attn_backend.name) + if is_prompt: # Prefill-phase scenario @@ -902,8 +905,7 @@ def make_test_metadata( context_lens, encoder_seq_lens, device=device) - - return attn_backend.make_metadata( + return attn_backend_obj.make_metadata( num_prefills=num_prefills, slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), multi_modal_placeholder_index_maps=None, @@ -952,7 +954,7 @@ def make_test_metadata( encoder_seq_lens, device=device) - return attn_backend.make_metadata( + return attn_backend_obj.make_metadata( num_prefills=num_prefills, slot_mapping=kv_mmap.slot_mapping, multi_modal_placeholder_index_maps=None, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index aed04361e5fb4..f5dcaea79af93 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -233,6 +233,7 @@ def __init__( kv_cache_dtype: str = "auto", blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: raise NotImplementedError @@ -246,7 +247,6 @@ def forward( attn_metadata: T, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 99cb84346d84e..7089d59392c36 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -300,6 +300,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: assert blocksparse_params is not None assert alibi_slopes is None, ValueError( @@ -350,6 +351,12 @@ def __init__( active_head_range=self.blocksparse_params.active_head_range, ) + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "BlocksparseFlashAttentionImpl") + def forward( self, query: torch.Tensor, @@ -359,7 +366,6 @@ def forward( attn_metadata: BlocksparseFlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -375,12 +381,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "BlocksparseFlashAttentionImpl") - num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index c69e12ad78c44..23ea244f07dfe 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -600,6 +600,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: if blocksparse_params is not None: raise ValueError( @@ -627,6 +628,7 @@ def __init__( raise ValueError( f"Head size {head_size} is not supported by FlashAttention. " f"Supported head sizes are: {support_head_sizes}.") + self.attn_type = attn_type def forward( self, @@ -637,7 +639,6 @@ def forward( attn_metadata: FlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -659,6 +660,7 @@ def forward( assert output is not None, "Output tensor must be provided." + attn_type = self.attn_type if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): raise AttributeError("Encoder attention requires setting " diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index e367468d05d26..a11462b2068a5 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -748,6 +748,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -764,6 +765,12 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashInferImpl") + def forward( self, query: torch.Tensor, @@ -773,18 +780,10 @@ def forward( attn_metadata: FlashInferMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: # TODO: directly write to output tensor - - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashInferImpl") - num_heads: int = self.num_heads head_size: int = self.head_size num_kv_heads: int = self.num_kv_heads diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index f90d15d4207e7..94a461e0c8c29 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -102,6 +102,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, max_seq_len: int = 4096, + attn_type: str = AttentionType.DECODER, ) -> None: super(AttentionImpl, self).__init__() self.kv_cache_dtype = kv_cache_dtype @@ -143,6 +144,12 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "HPUAttentionImpl") + def forward( self, query: torch.Tensor, @@ -152,7 +159,6 @@ def forward( attn_metadata: HPUAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -166,11 +172,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "HPUAttentionImpl") batch_size, seq_len, hidden_size = query.shape _, seq_len_kv, _ = key.shape diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 21949874bea47..da1d307daa517 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -115,6 +115,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: if blocksparse_params is not None: raise ValueError( @@ -146,6 +147,11 @@ def __init__( raise NotImplementedError( "IPEX backend does not support FP8 KV cache. " "Please use xFormers backend instead.") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "IpexAttnBackendImpl") def split_kv_cache( self, @@ -172,7 +178,6 @@ def forward( attn_metadata: IpexAttnMetadata, # type: ignore k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. @@ -189,11 +194,6 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert k_scale == 1.0 and v_scale == 1.0 - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "IpexAttnBackendImpl") num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 9809aed0e66f9..2ac492dd8ae54 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -100,6 +100,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -141,6 +142,12 @@ def __init__( # megacore mode will be None. self.megacore_mode = "batch" + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "PallasAttentionBackendImpl") + def forward( self, query: torch.Tensor, @@ -150,7 +157,6 @@ def forward( attn_metadata: PallasMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -168,11 +174,6 @@ def forward( shape = [batch_size, seq_len, num_heads * head_size] """ assert k_scale == 1.0 and v_scale == 1.0 - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "PallasAttentionBackendImpl") batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index d43c15b661ef7..a91a5af5c3d58 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -338,6 +338,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: if blocksparse_params is not None: raise ValueError( @@ -397,6 +398,12 @@ def __init__( self.attn_func = _sdpa_attention logger.debug("Using naive attention in ROCmBackend") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "ROCmFlashAttentionImpl") + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" tokens, n_kv_heads, head_dim = x.shape @@ -414,7 +421,6 @@ def forward( attn_metadata: ROCmFlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -432,12 +438,6 @@ def forward( """ # Reminder: Please update docs/source/features/compatibility_matrix.md # If the feature combo become valid - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "ROCmFlashAttentionImpl") - num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 0cff6f5952aba..c14f7754596dd 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -390,6 +390,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: if blocksparse_params is not None: raise ValueError( @@ -421,6 +422,7 @@ def __init__( raise NotImplementedError( "Torch SDPA backend does not support FP8 KV cache. " "Please use xFormers backend instead.") + self.attn_type = attn_type def forward( self, @@ -431,7 +433,6 @@ def forward( attn_metadata: TorchSDPAMetadata, # type: ignore k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -448,6 +449,7 @@ def forward( shape = [num_tokens, num_heads * head_size] """ assert k_scale == 1.0 and v_scale == 1.0 + attn_type = self.attn_type if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): raise AttributeError("Encoder attention requires setting " diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 3e59b3603d2c6..694c7cc1bc36a 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -379,6 +379,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: str = AttentionType.DECODER, ) -> None: if blocksparse_params is not None: raise ValueError( @@ -405,6 +406,8 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") + self.attn_type = attn_type + def forward( self, query: torch.Tensor, @@ -414,7 +417,6 @@ def forward( attn_metadata: "XFormersMetadata", k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: str = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -468,7 +470,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - + attn_type = self.attn_type # Check that appropriate attention metadata attributes are # selected for the desired attention type if (attn_type == AttentionType.ENCODER diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 69b6d1e4648df..f1b3598e60b54 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -41,6 +41,7 @@ def __init__( logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, prefix: str = "", + attn_type: str = AttentionType.DECODER, ) -> None: super().__init__() if per_layer_sliding_window is not None: @@ -96,7 +97,7 @@ def __init__( impl_cls = attn_backend.get_impl_cls() self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, - blocksparse_params, logits_soft_cap) + blocksparse_params, logits_soft_cap, attn_type) self.num_heads = num_heads self.head_size = head_size self.num_kv_heads = num_kv_heads @@ -119,6 +120,7 @@ def __init__( raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self self.layer_name = prefix + self.attn_type = attn_type def forward( self, @@ -127,18 +129,12 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: if self.use_direct_call: - return self.impl.forward(query, - key, - value, - kv_cache, - attn_metadata, - self._k_scale, - self._v_scale, - attn_type=attn_type) + return self.impl.forward(query, key, value, kv_cache, + attn_metadata, self._k_scale, + self._v_scale) elif self.use_output: output = torch.empty_like(query) hidden_size = query.size(-1) @@ -152,13 +148,11 @@ def forward( if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size) torch.ops.vllm.unified_attention_with_output( - query, key, value, output, kv_cache, attn_type, - self.layer_name) + query, key, value, output, kv_cache, self.layer_name) return output.view(-1, hidden_size) else: return torch.ops.vllm.unified_attention(query, key, value, - kv_cache, attn_type, - self.layer_name) + kv_cache, self.layer_name) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore @@ -237,20 +231,13 @@ def unified_attention( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_type: str, layer_name: str, ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.dynamic_forward_context self = forward_context.static_forward_context[layer_name] - return self.impl.forward(query, - key, - value, - kv_cache, - attn_metadata, - self._k_scale, - self._v_scale, - attn_type=attn_type) + return self.impl.forward(query, key, value, kv_cache, attn_metadata, + self._k_scale, self._v_scale) def unified_attention_fake( @@ -258,7 +245,6 @@ def unified_attention_fake( key: torch.Tensor, value: torch.Tensor, kv_cache: torch.Tensor, - attn_type: str, layer_name: str, ) -> torch.Tensor: return torch.empty_like(query).contiguous() @@ -279,7 +265,6 @@ def unified_attention_with_output( value: torch.Tensor, output: torch.Tensor, kv_cache: torch.Tensor, - attn_type: str, layer_name: str, ) -> None: forward_context: ForwardContext = get_forward_context() @@ -292,7 +277,6 @@ def unified_attention_with_output( attn_metadata, self._k_scale, self._v_scale, - attn_type=attn_type, output=output) @@ -302,7 +286,6 @@ def unified_attention_with_output_fake( value: torch.Tensor, output: torch.Tensor, kv_cache: torch.Tensor, - attn_type: str, layer_name: str, ) -> None: return diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 3776490cb3465..57eb5adc82d5b 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -71,12 +71,8 @@ def __init__(self, num_embeddings: int, embedding_dim: int): def forward( self, positions: torch.Tensor, - attn_type: AttentionType, ) -> torch.Tensor: """`input_ids' shape is expected to be [bsz x seqlen].""" - - assert attn_type != AttentionType.ENCODER_DECODER - return super().forward(positions + self.offset) @@ -180,7 +176,8 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER) def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: @@ -189,12 +186,7 @@ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, - k, - v, - kv_cache, - attn_metadata, - attn_type=AttentionType.ENCODER) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.out_proj(attn_output) return output @@ -264,7 +256,8 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + attn_type=AttentionType.DECODER) def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: @@ -273,12 +266,7 @@ def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, - k, - v, - kv_cache, - attn_metadata, - attn_type=AttentionType.DECODER) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.out_proj(attn_output) return output @@ -348,7 +336,8 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_DECODER) def forward( self, @@ -372,12 +361,7 @@ def forward( _, k, v = qkv_enc.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - attn_output = self.attn(q, - k, - v, - kv_cache, - attn_metadata, - attn_type=AttentionType.ENCODER_DECODER) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.out_proj(attn_output) return output @@ -644,10 +628,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, # retrieve input_ids and inputs_embeds inputs_embeds = self.embed_tokens(input_ids) - embed_pos = self.embed_positions( - positions, - AttentionType.ENCODER, - ) + embed_pos = self.embed_positions(positions) embed_pos = embed_pos.to(inputs_embeds.device) hidden_states = inputs_embeds + embed_pos @@ -734,10 +715,7 @@ def forward(self, decoder_input_ids: torch.Tensor, inputs_embeds = self.embed_tokens(decoder_input_ids) # embed positions - embed_pos = self.embed_positions( - decoder_positions, - AttentionType.DECODER, - ) + embed_pos = self.embed_positions(decoder_positions) embed_pos = embed_pos.to(inputs_embeds.device) hidden_states = inputs_embeds + embed_pos diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index c1d47b1bc9bcd..4be136543de15 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -238,7 +238,8 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_ONLY) def forward( self, @@ -248,12 +249,7 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - output = self.attn(q, - k, - v, - kv_cache, - attn_metadata, - attn_type=AttentionType.ENCODER_ONLY) + output = self.attn(q, k, v, kv_cache, attn_metadata) return output diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 6536f9807730c..c5046e06edecb 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -770,6 +770,7 @@ def __init__( self.scaling, self.num_local_key_value_heads, prefix=f"{prefix}.attn", + attn_type=AttentionType.ENCODER_DECODER, ) def forward( @@ -805,13 +806,9 @@ def forward( kv_range_for_decode, attn_metadata) else: - output = self.attn(q.view(-1, - self.num_local_heads * self.head_dim), - k, - v, - kv_cache, - attn_metadata, - attn_type=AttentionType.ENCODER_DECODER) + output = self.attn( + q.view(-1, self.num_local_heads * self.head_dim), k, v, + kv_cache, attn_metadata) out, _ = self.o_proj(output) return out diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 88f4ea4352726..01745b5fd53e1 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -107,7 +107,8 @@ def __init__(self, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, rope_scaling: Optional[Tuple] = None, - prefix: str = "") -> None: + prefix: str = "", + attn_type: str = AttentionType.DECODER) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -160,7 +161,8 @@ def __init__(self, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.attn") + prefix=f"{prefix}.attn", + attn_type=attn_type) def forward( self, @@ -168,17 +170,11 @@ def forward( hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, - k, - v, - kv_cache, - attn_metadata, - attn_type=attn_type) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) output, _ = self.o_proj(attn_output) return output @@ -197,6 +193,16 @@ def __init__( # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 1000000) rope_scaling = getattr(config, "rope_scaling", None) + + # By default, Qwen2 uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct) + if getattr(config, "is_causal", True): + attn_type = AttentionType.DECODER + else: + attn_type = AttentionType.ENCODER_ONLY + self.self_attn = Qwen2Attention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -207,6 +213,7 @@ def __init__( quant_config=quant_config, rope_scaling=rope_scaling, prefix=f"{prefix}.self_attn", + attn_type=attn_type, ) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, @@ -220,15 +227,6 @@ def __init__( self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # By default, Qwen2 uses causal attention as it is a decoder-only model. - # You can override the HF config with `is_causal=False` to enable - # bidirectional attention, which is used in some embedding models - # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct) - if getattr(config, "is_causal", True): - self._attn_type = AttentionType.DECODER - else: - self._attn_type = AttentionType.ENCODER_ONLY - def forward( self, positions: torch.Tensor, @@ -249,7 +247,6 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, - attn_type=self._attn_type, ) # Fully Connected diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 65002f1ad70c7..b02bc9ffde538 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -89,6 +89,7 @@ def __init__( kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + attn_type: AttentionType = AttentionType.DECODER, ) -> None: if blocksparse_params is not None: raise ValueError( @@ -119,6 +120,12 @@ def __init__( f"Head size {head_size} is not supported by FlashAttention. " f"Supported head sizes are: {support_head_sizes}.") + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttentionImpl") + def forward( self, query: torch.Tensor, @@ -128,7 +135,6 @@ def forward( attn_metadata: FlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -142,12 +148,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " - "are not implemented for " - "FlashAttentionImpl") - # NOTE(woosuk): FlashAttention does not support FP8 KV cache. assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") From 91b361ae898c944f823534121613f9d3dc19d7d1 Mon Sep 17 00:00:00 2001 From: Roger Wang <136131678+ywang96@users.noreply.github.com> Date: Mon, 6 Jan 2025 11:58:16 -0800 Subject: [PATCH 003/114] [V1] Extend beyond image modality and support mixed-modality inference with Llava-OneVision (#11685) Signed-off-by: Roger Wang Signed-off-by: DarkLight1337 Co-authored-by: DarkLight1337 --- docs/source/models/supported_models.md | 2 +- tests/multimodal/test_utils.py | 209 +++++++++++++++++- tests/v1/core/test_kv_cache_utils.py | 18 +- tests/v1/core/test_prefix_caching.py | 17 +- vllm/model_executor/models/interfaces.py | 6 +- vllm/model_executor/models/llava_onevision.py | 65 +++--- vllm/model_executor/models/molmo.py | 3 - vllm/multimodal/__init__.py | 3 + vllm/multimodal/hasher.py | 100 +++++++++ vllm/multimodal/inputs.py | 9 +- vllm/multimodal/processing.py | 92 +++----- vllm/multimodal/utils.py | 86 ++++++- vllm/v1/engine/__init__.py | 18 +- vllm/v1/engine/mm_input_mapper.py | 67 ------ vllm/v1/engine/processor.py | 101 ++++++--- vllm/v1/request.py | 48 ++-- vllm/v1/worker/gpu_model_runner.py | 74 ++++--- 17 files changed, 636 insertions(+), 282 deletions(-) create mode 100644 vllm/multimodal/hasher.py diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index 5a2778026192a..94a8849f7edcd 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -647,7 +647,7 @@ See [this page](#generative-models) for more information on how to use generativ - `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. - - ✅︎ - - + - ✅︎ * - `MiniCPMV` - MiniCPM-V - T + IE+ diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 6029f2e514772..198344e5bd88c 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -2,16 +2,22 @@ import mimetypes import os from tempfile import NamedTemporaryFile, TemporaryDirectory -from typing import Dict, Tuple +from typing import TYPE_CHECKING, Dict, NamedTuple, Optional, Tuple import numpy as np import pytest from PIL import Image, ImageChops from transformers import AutoConfig, AutoTokenizer +from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.utils import (MediaConnector, + merge_and_sort_multimodal_metadata, repeat_and_pad_placeholder_tokens) +if TYPE_CHECKING: + from vllm.multimodal.hasher import MultiModalHashDict + from vllm.multimodal.inputs import MultiModalPlaceholderDict + # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) TEST_IMAGE_URLS = [ "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", @@ -191,3 +197,204 @@ def test_repeat_and_pad_placeholder_tokens(model): assert new_prompt == expected_prompt assert new_token_ids == expected_token_ids assert ranges == expected_ranges + + +# Used for the next two tests related to `merge_and_sort_multimodal_metadata`. +class TestCase(NamedTuple): + mm_positions: "MultiModalPlaceholderDict" + mm_hashes: Optional["MultiModalHashDict"] + expected_modalities: list[str] + expected_ranges: list[PlaceholderRange] + expected_hashes: Optional[list[str]] + + +def test_merge_and_sort_multimodal_metadata(): + + test_cases = [ + # Single modality should return result as is but flattened + TestCase( + mm_positions={ + "image": [ + PlaceholderRange(offset=0, length=2), + PlaceholderRange(offset=3, length=2), + ] + }, + mm_hashes={"image": ["hash1", "hash2"]}, + expected_modalities=["image"], + expected_ranges=[ + PlaceholderRange(offset=0, length=2), + PlaceholderRange(offset=3, length=2), + ], + expected_hashes=["hash1", "hash2"], + ), + + # Single modality without hashes return None for mm hash. + TestCase( + mm_positions={ + "image": [ + PlaceholderRange(offset=0, length=2), + PlaceholderRange(offset=2, length=2), + ] + }, + mm_hashes=None, + expected_modalities=["image"], + expected_ranges=[ + PlaceholderRange(offset=0, length=2), + PlaceholderRange(offset=2, length=2), + ], + expected_hashes=None, + ), + + # Multiple modalities with hashes should return sorted modalities + # and flattened ranges and hashes. + TestCase( + mm_positions={ + "image": [ + PlaceholderRange(offset=7, length=4), + PlaceholderRange(offset=11, length=5), + ], + "audio": [ + PlaceholderRange(offset=0, length=2), + PlaceholderRange(offset=2, length=3), + ] + }, + mm_hashes={ + "image": ["image_hash1", "image_hash2"], + "audio": ["audio_hash1", "audio_hash2"], + }, + expected_modalities=["audio", "image"], + expected_ranges=[ + PlaceholderRange(offset=0, length=2), + PlaceholderRange(offset=2, length=3), + PlaceholderRange(offset=7, length=4), + PlaceholderRange(offset=11, length=5), + ], + expected_hashes=[ + "audio_hash1", "audio_hash2", "image_hash1", "image_hash2" + ], + ), + + # Multiple modalities without hashes should return sorted modalities + # and flattened ranges and None. + TestCase( + mm_positions={ + "image": [ + PlaceholderRange(offset=7, length=4), + PlaceholderRange(offset=11, length=5), + ], + "audio": [ + PlaceholderRange(offset=0, length=2), + PlaceholderRange(offset=2, length=3), + ] + }, + mm_hashes=None, + expected_modalities=["audio", "image"], + expected_ranges=[ + PlaceholderRange(offset=0, length=2), + PlaceholderRange(offset=2, length=3), + PlaceholderRange(offset=7, length=4), + PlaceholderRange(offset=11, length=5), + ], + expected_hashes=None, + ), + + # Three modalities + TestCase( + mm_positions={ + "image": [ + PlaceholderRange(offset=15, length=7), + PlaceholderRange(offset=22, length=8), + ], + "audio": [ + PlaceholderRange(offset=0, length=2), + ], + "video": [ + PlaceholderRange(offset=3, length=4), + PlaceholderRange(offset=7, length=5), + PlaceholderRange(offset=12, length=6), + ] + }, + mm_hashes={ + "image": ["image_hash1", "image_hash2"], + "audio": ["audio_hash1"], + "video": ["video_hash1", "video_hash2", "video_hash3"] + }, + expected_modalities=["audio", "video", "image"], + expected_ranges=[ + PlaceholderRange(offset=0, length=2), + PlaceholderRange(offset=3, length=4), + PlaceholderRange(offset=7, length=5), + PlaceholderRange(offset=12, length=6), + PlaceholderRange(offset=15, length=7), + PlaceholderRange(offset=22, length=8), + ], + expected_hashes=[ + "audio_hash1", "video_hash1", "video_hash2", "video_hash3", + "image_hash1", "image_hash2" + ], + ), + ] + + for (mm_positions, mm_hashes, expected_modalities, expected_ranges, + expected_hashes) in test_cases: + modalities, ranges, hashes = merge_and_sort_multimodal_metadata( + mm_positions, mm_hashes) + + assert modalities == expected_modalities + assert ranges == expected_ranges + assert hashes == expected_hashes + + +def test_merge_and_sort_multimodal_metadata_with_interleaving(): + + test_cases = [ + + #