From 9aa3acc16ede7d2973933770894a7d95d0969912 Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Thu, 7 Nov 2024 05:23:18 +0000
Subject: [PATCH] Init

Signed-off-by: Jee Jee Li
---
 vllm/model_executor/models/mllama.py | 43 +++++++++++++++++++++++++---
 vllm/worker/enc_dec_model_runner.py  | 31 ++++++++++++++++++--
 vllm/worker/utils.py                 |  4 ---
 3 files changed, 67 insertions(+), 11 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 5fa8d19b97fe8..3c369fc7f4541 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -33,7 +33,7 @@
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.config import CacheConfig, MultiModalConfig
+from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs,
                          InputContext, TokenInputs, token_inputs)
@@ -48,13 +48,14 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import SequenceData
 from vllm.utils import is_list_of
 
 from .clip import CLIPMLP
-from .interfaces import SupportsMultiModal
+from .interfaces import SupportsLoRA, SupportsMultiModal
 from .llama import LlamaDecoderLayer, LlamaMLP
 
 logger = init_logger(__name__)
@@ -1083,7 +1084,29 @@ def forward(
 @INPUT_REGISTRY.register_dummy_data(dummy_decoder_data_for_mllama)
 @INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
-class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
+class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                     SupportsLoRA):
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"]
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        # vision encoder
+        "fc1",
+        "fc2",
+        # language model
+        "qkv_proj",  # same name with vision encoder
+        "gate_up_proj",
+        "o_proj",  # same name with vision encoder
+        "down_proj",
+        # projector
+        "multi_modal_projector",
+    ]
+    embedding_modules = {}
+    embedding_padding_modules = []
+
     # BitandBytes specific attributes
     default_bitsandbytes_target_modules = [
         ".gate_proj.",
@@ -1112,8 +1135,11 @@ def __init__(self,
                  config: config_mllama.MllamaConfig,
                  multimodal_config: MultiModalConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 lora_config: Optional[LoRAConfig] = None):
         super().__init__()
+        # LRUCacheWorkerLoRAManager instantiation requires self.config.
+        self.config = config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size
         self.max_num_tiles = config.vision_config.max_num_tiles
@@ -1442,6 +1468,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="multi_modal_projector",
+            tower_model="vision_model")
+
 
 def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
     for mask in sparse_mask:
diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py
index 90a43196084ea..c006bcd971251 100644
--- a/vllm/worker/enc_dec_model_runner.py
+++ b/vllm/worker/enc_dec_model_runner.py
@@ -15,6 +15,7 @@
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
 from vllm.model_executor import SamplingMetadata
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.utils import get_architecture_class_name
@@ -34,7 +35,7 @@
 from vllm.worker.utils import assert_enc_dec_mr_supported_scenario
 
 logger = init_logger(__name__)
-
+LORA_WARMUP_RANK = 8
 # The Mllama model has PagedAttention specific logic because of which it
 # can only be run with the XFORMERS backend
 # TODO Make Mllama model work with Flash Attention backend.
@@ -180,7 +181,11 @@ def execute_model(
         if num_steps > 1:
             raise ValueError("num_steps > 1 is not supported in "
                              "EncoderDecoderModelRunner")
-
+        if self.lora_config:
+            assert model_input.lora_requests is not None
+            assert model_input.lora_mapping is not None
+            self.set_active_loras(model_input.lora_requests,
+                                  model_input.lora_mapping)
         if (model_input.attn_metadata is not None
                 and model_input.attn_metadata.prefill_metadata is None
                 and model_input.attn_metadata.decode_metadata.use_cuda_graph):
@@ -288,7 +293,25 @@ def profile_run(self) -> None:
         sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
         max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
         max_num_seqs = self.scheduler_config.max_num_seqs
-
+        dummy_lora_requests: List[LoRARequest] = []
+        dummy_lora_requests_per_seq: List[LoRARequest] = []
+        if self.lora_config:
+            assert self.lora_manager is not None
+            with self.lora_manager.dummy_lora_cache():
+                for idx in range(self.lora_config.max_loras):
+                    lora_id = idx + 1
+                    dummy_lora_request = LoRARequest(
+                        lora_name=f"warmup_{lora_id}",
+                        lora_int_id=lora_id,
+                        lora_path="/not/a/real/path",
+                    )
+                    self.lora_manager.add_dummy_lora(dummy_lora_request,
+                                                     rank=LORA_WARMUP_RANK)
+                    dummy_lora_requests.append(dummy_lora_request)
+            dummy_lora_requests_per_seq = [
+                dummy_lora_requests[idx % len(dummy_lora_requests)]
+                for idx in range(max_num_seqs)
+            ]
         # Profile memory usage with max_num_sequences sequences and the total
         # number of tokens equal to max_num_batched_tokens.
         seqs: List[SequenceGroupMetadata] = []
@@ -337,6 +360,8 @@ def profile_run(self) -> None:
                 block_tables=None,
                 encoder_seq_data=encoder_dummy_data.seq_data,
                 cross_block_table=None,
+                lora_request=dummy_lora_requests_per_seq[group_id]
+                if dummy_lora_requests_per_seq else None,
                 multi_modal_data=decoder_dummy_data.multi_modal_data
                 or encoder_dummy_data.multi_modal_data,
                 multi_modal_placeholders=decoder_dummy_data.
diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py
index f43635464ef00..884f812ee46ae 100644
--- a/vllm/worker/utils.py
+++ b/vllm/worker/utils.py
@@ -34,10 +34,6 @@ def assert_enc_dec_mr_supported_scenario(
             STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP']
         )
 
-    if enc_dec_mr.lora_config is not None:
-        raise NotImplementedError(
-            STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_LORA'])
-
     if enc_dec_mr.parallel_config.pipeline_parallel_size > 1:
         raise NotImplementedError(
             STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP'])
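
Note (not part of the patch): with the LoRA gate removed from assert_enc_dec_mr_supported_scenario and MllamaForConditionalGeneration implementing SupportsLoRA, an adapter can be attached per request through vLLM's regular LoRA entry points. The snippet below is a minimal sketch of how that would be exercised with offline inference, assuming this patch is applied; the model id is just an example Mllama checkpoint, the adapter path, adapter name, and prompt are placeholders, and image inputs (normally supplied via multi_modal_data) are omitted for brevity.

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # Placeholder checkpoint/adapter locations -- substitute real ones.
    llm = LLM(
        model="meta-llama/Llama-3.2-11B-Vision-Instruct",
        enable_lora=True,   # enables the LoRA code paths touched by this patch
        max_loras=1,
        max_lora_rank=16,
    )
    outputs = llm.generate(
        "Describe the picture in one sentence.",
        SamplingParams(temperature=0.0, max_tokens=64),
        # (lora_name, lora_int_id, lora_path); the id keys the adapter cache.
        lora_request=LoRARequest("mllama-adapter", 1, "/path/to/lora/adapter"),
    )
    print(outputs[0].outputs[0].text)

Because the adapter is selected per request, the same engine can serve requests with and without it; the dummy adapters added in profile_run (at rank LORA_WARMUP_RANK) make the memory profile account for LoRA overhead before the KV cache size is fixed.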