From 960c605ca0073a5dd1adf856add5ef5a495aaf6e Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 25 Oct 2024 14:12:56 +0800 Subject: [PATCH 01/16] fix phi3 gguf --- vllm/model_executor/layers/linear.py | 48 ++++++++++++++++++---------- vllm/model_executor/models/llama.py | 3 +- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 94f30412e43b3..1808dfc53072e 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -440,8 +440,11 @@ def weight_loader(self, is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) if is_gguf_weight_type: - param.data[loaded_shard_id].copy_(loaded_weight) - param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + if loaded_shard_id is not None: + param.data[loaded_shard_id].copy_(loaded_weight) + param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + else: + param.weight_type = loaded_weight.item() return if is_gguf_weight: @@ -455,11 +458,16 @@ def weight_loader(self, loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - param.shard_id.append(loaded_shard_id) - param.shard_id_map[loaded_shard_id] = len(param.data_container) - param.data_container.append(loaded_weight) - if len(param.data_container) == 2: - self.qweight = param.materialize_nested() + if loaded_shard_id is not None: + param.shard_id.append(loaded_shard_id) + param.shard_id_map[loaded_shard_id] = len(param.data_container) + param.data_container.append(loaded_weight) + if len(param.data_container) == 2: + self.qweight = param.materialize_nested() + else: + param.materialize(loaded_weight.shape, + dtype=loaded_weight.dtype) + param.data.copy_(loaded_weight) return param_data = param.data @@ -775,10 +783,13 @@ def weight_loader(self, # initialize GGUF param after we know the quantize type is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) - if is_gguf_weight_type and loaded_shard_id is not None: - idx_map = {"q": 0, "k": 1, "v": 2} - param.data[idx_map[loaded_shard_id]].copy_(loaded_weight) - param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + if is_gguf_weight_type: + if loaded_shard_id is not None: + idx_map = {"q": 0, "k": 1, "v": 2} + param.data[idx_map[loaded_shard_id]].copy_(loaded_weight) + param.shard_weight_type[loaded_shard_id] = loaded_weight.item() + else: + param.weight_type = loaded_weight.item() return if is_gguf_weight: @@ -792,11 +803,16 @@ def weight_loader(self, loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) - param.shard_id.append(loaded_shard_id) - param.shard_id_map[loaded_shard_id] = len(param.data_container) - param.data_container.append(loaded_weight) - if len(param.data_container) == 3: - self.qweight = param.materialize_nested() + if loaded_shard_id is not None: + param.shard_id.append(loaded_shard_id) + param.shard_id_map[loaded_shard_id] = len(param.data_container) + param.data_container.append(loaded_weight) + if len(param.data_container) == 3: + self.qweight = param.materialize_nested() + else: + param.materialize(loaded_weight.shape, + dtype=loaded_weight.dtype) + param.data.copy_(loaded_weight) return param_data = param.data diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c346e3e808e3f..08964fa4f706e 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -156,7 +156,8 @@ def __init__( ) is_neox_style = True - if quant_config is not None and quant_config.get_name() == "gguf": + is_gguf = quant_config and quant_config.get_name() == "gguf" + if is_gguf and config.model_type == "llama": is_neox_style = False self.rotary_emb = get_rope( From 0551d73504b19738a99892740ef00e43f2a20e31 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 25 Oct 2024 14:26:54 +0800 Subject: [PATCH 02/16] update gguf example --- examples/gguf_inference.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/examples/gguf_inference.py b/examples/gguf_inference.py index 09a5fcc22e553..aa05c4c0bfaa5 100644 --- a/examples/gguf_inference.py +++ b/examples/gguf_inference.py @@ -3,27 +3,20 @@ from vllm import LLM, SamplingParams -def run_gguf_inference(model_path): - PROMPT_TEMPLATE = "<|system|>\n{system_message}</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n" # noqa: E501 - system_message = "You are a friendly chatbot who always responds in the style of a pirate." # noqa: E501 +def run_gguf_inference(model_path, tokenizer): # Sample prompts. prompts = [ "How many helicopters can a human eat in one sitting?", "What's the future of AI?", ] - prompts = [ - PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt) - for prompt in prompts - ] + prompts = [[{"role": "user", "content": prompt}] for prompt in prompts] # Create a sampling params object. sampling_params = SamplingParams(temperature=0, max_tokens=128) # Create an LLM. - llm = LLM(model=model_path, - tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0", - gpu_memory_utilization=0.95) + llm = LLM(model=model_path, tokenizer=tokenizer) - outputs = llm.generate(prompts, sampling_params) + outputs = llm.chat(prompts, sampling_params) # Print the outputs. for output in outputs: prompt = output.prompt @@ -32,7 +25,8 @@ def run_gguf_inference(model_path): if __name__ == "__main__": - repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF" - filename = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf" + repo_id = "bartowski/Phi-3-medium-4k-instruct-GGUF" + filename = "Phi-3-medium-4k-instruct-IQ2_M.gguf" + tokenizer = "microsoft/Phi-3-medium-4k-instruct" model = hf_hub_download(repo_id, filename=filename) - run_gguf_inference(model) + run_gguf_inference(model, tokenizer) From 1ed74daf1f6890dea1145e89eb463e2411a9ef9d Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 25 Oct 2024 15:21:37 +0800 Subject: [PATCH 03/16] fix stablelm and starcoder2 --- vllm/model_executor/models/stablelm.py | 4 +++- vllm/model_executor/models/starcoder2.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 083a48588d01a..1ef192e136b7f 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -62,7 +62,8 @@ def __init__(self, quant_config=quant_config) self.down_proj = RowParallelLinear(config.intermediate_size, config.hidden_size, - bias=False) + bias=False, + quant_config=quant_config) self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -203,6 +204,7 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, + quant_config=quant_config, ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 8f0644bca3e2e..123bcd0a42f91 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -209,7 +209,8 @@ def __init__(self, # TODO: consider padding_idx (currently removed) self.embed_tokens = VocabParallelEmbedding(config.vocab_size, - config.hidden_size) + config.hidden_size, + quant_config=quant_config) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Starcoder2DecoderLayer( From b3f0e43be9967269a10d1cfdd28a38ead4188eb4 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Sat, 26 Oct 2024 12:07:55 +0800 Subject: [PATCH 04/16] fix gpt2 --- vllm/model_executor/models/gpt2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 3330d84021368..8b42918c7cc0d 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -199,7 +199,7 @@ def __init__( assert not config.scale_attn_by_inverse_layer_idx assert not config.reorder_and_upcast_attn self.embed_dim = config.hidden_size - self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim, quant_config=quant_config) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, @@ -259,7 +259,8 @@ def __init__( self.lm_head = self.transformer.wte else: self.lm_head = ParallelLMHead(self.config.vocab_size, - self.config.hidden_size) + self.config.hidden_size, + quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() self.make_empty_intermediate_tensors = ( @@ -297,7 +298,7 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: - if "lm_head.weight" in name: + if "lm_head" in name: # GPT-2 ties the weights of the embedding layer and the final # linear layer. continue From 01dc5c8bc2dfbbdfcdc6f51764ca3db5942af5fc Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Mon, 28 Oct 2024 17:27:36 +0800 Subject: [PATCH 05/16] refactor gguf test --- .../models/decoder_only/language/test_gguf.py | 88 ++++++++++++------- vllm/model_executor/models/gpt2.py | 4 +- 2 files changed, 60 insertions(+), 32 deletions(-) diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py index 5dc83942632fd..335b91696f437 100644 --- a/tests/models/decoder_only/language/test_gguf.py +++ b/tests/models/decoder_only/language/test_gguf.py @@ -4,33 +4,70 @@ """ import os +from typing import List, NamedTuple, Type import pytest from huggingface_hub import hf_hub_download -from transformers import AutoTokenizer from tests.quantization.utils import is_quant_method_supported +from ....conftest import VllmRunner from ...utils import check_logprobs_close os.environ["TOKENIZERS_PARALLELISM"] = "true" MAX_MODEL_LEN = 1024 -# FIXME: Move this to confest + +class GGUFTestConfig(NamedTuple): + original_model: str + gguf_repo: str + gguf_filename: str + + @property + def gguf_model(self): + return hf_hub_download(self.gguf_repo, filename=self.gguf_filename) + + +LLAMA_CONFIG = GGUFTestConfig( + original_model="meta-llama/Llama-3.2-1B-Instruct", + gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF", + gguf_filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf", +) + +QWEN2_CONFIG = GGUFTestConfig( + original_model="Qwen/Qwen2-1.5B-Instruct", + gguf_repo="Qwen/Qwen2.5-1.5B-Instruct-GGUF", + gguf_filename="qwen2.5-1.5b-instruct-q4_k_m.gguf", +) + +PHI3_CONFIG = GGUFTestConfig( + original_model="microsoft/Phi-3.5-mini-instruct", + gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF", + gguf_filename="Phi-3.5-mini-instruct-IQ4_XS.gguf", +) + +GPT2_CONFIG = GGUFTestConfig( + original_model="openai-community/gpt2-large", + gguf_repo="QuantFactory/gpt2-large-GGUF", + gguf_filename="gpt2-large.Q4_K_M.gguf", +) + +STABLELM_CONFIG = GGUFTestConfig( + original_model="afrideva/stablelm-3b-4e1t-GGUF", + gguf_repo="afrideva/stablelm-3b-4e1t-GGUF", + gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf", +) + +STARCODER_CONFIG = GGUFTestConfig( + original_model="bigcode/starcoder2-3b", + gguf_repo="QuantFactory/starcoder2-3b-GGUF", + gguf_filename="starcoder2-3b.Q4_K_M.gguf", +) + MODELS = [ - ("meta-llama/Llama-3.2-1B-Instruct", - hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF", - filename="Llama-3.2-1B-Instruct-Q4_K_M.gguf")), - ("meta-llama/Llama-3.2-1B-Instruct", - hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF", - filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf")), - ("Qwen/Qwen2-1.5B-Instruct", - hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF", - filename="qwen2-1_5b-instruct-q4_k_m.gguf")), - ("Qwen/Qwen2-1.5B-Instruct", - hf_hub_download("legraphista/Qwen2-1.5B-Instruct-IMat-GGUF", - filename="Qwen2-1.5B-Instruct.IQ4_XS.gguf")), + LLAMA_CONFIG, QWEN2_CONFIG, PHI3_CONFIG, GPT2_CONFIG, STABLELM_CONFIG, + STARCODER_CONFIG ] @@ -42,10 +79,10 @@ @pytest.mark.parametrize("num_logprobs", [5]) @pytest.mark.parametrize("tp_size", [1, 2]) def test_models( - num_gpus_available, - vllm_runner, - example_prompts, - model, + num_gpus_available: int, + vllm_runner: Type[VllmRunner], + example_prompts: List[str], + model: GGUFTestConfig, dtype: str, max_tokens: int, num_logprobs: int, @@ -54,19 +91,8 @@ def test_models( if num_gpus_available < tp_size: pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") - original_model, gguf_model = model - - tokenizer = AutoTokenizer.from_pretrained(original_model) - messages = [[{ - 'role': 'user', - 'content': prompt - }] for prompt in example_prompts] - example_prompts = tokenizer.apply_chat_template(messages, - tokenize=False, - add_generation_prompt=True) - # Run unquantized model. - with vllm_runner(model_name=original_model, + with vllm_runner(model_name=model.original_model, dtype=dtype, max_model_len=MAX_MODEL_LEN, tensor_parallel_size=tp_size) as original_model: @@ -75,7 +101,7 @@ def test_models( example_prompts[:-1], max_tokens, num_logprobs) # Run gguf model. - with vllm_runner(model_name=gguf_model, + with vllm_runner(model_name=model.gguf_model, dtype=dtype, max_model_len=MAX_MODEL_LEN, tensor_parallel_size=tp_size) as gguf_model: diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 8b42918c7cc0d..398b0714187e1 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -199,7 +199,9 @@ def __init__( assert not config.scale_attn_by_inverse_layer_idx assert not config.reorder_and_upcast_attn self.embed_dim = config.hidden_size - self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim, quant_config=quant_config) + self.wte = VocabParallelEmbedding(config.vocab_size, + self.embed_dim, + quant_config=quant_config) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, From e3e858111ee14959558b6158194a1774de25e39c Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Tue, 29 Oct 2024 01:16:24 +0800 Subject: [PATCH 06/16] fix failed tests --- .../models/decoder_only/language/test_gguf.py | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py index 335b91696f437..76aa3b2c74a38 100644 --- a/tests/models/decoder_only/language/test_gguf.py +++ b/tests/models/decoder_only/language/test_gguf.py @@ -8,6 +8,7 @@ import pytest from huggingface_hub import hf_hub_download +from transformers import AutoTokenizer from tests.quantization.utils import is_quant_method_supported @@ -36,7 +37,7 @@ def gguf_model(self): ) QWEN2_CONFIG = GGUFTestConfig( - original_model="Qwen/Qwen2-1.5B-Instruct", + original_model="Qwen/Qwen2.5-1.5B-Instruct", gguf_repo="Qwen/Qwen2.5-1.5B-Instruct-GGUF", gguf_filename="qwen2.5-1.5b-instruct-q4_k_m.gguf", ) @@ -54,7 +55,7 @@ def gguf_model(self): ) STABLELM_CONFIG = GGUFTestConfig( - original_model="afrideva/stablelm-3b-4e1t-GGUF", + original_model="stabilityai/stablelm-3b-4e1t", gguf_repo="afrideva/stablelm-3b-4e1t-GGUF", gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf", ) @@ -62,12 +63,16 @@ def gguf_model(self): STARCODER_CONFIG = GGUFTestConfig( original_model="bigcode/starcoder2-3b", gguf_repo="QuantFactory/starcoder2-3b-GGUF", - gguf_filename="starcoder2-3b.Q4_K_M.gguf", + gguf_filename="starcoder2-3b.Q6_K.gguf", ) MODELS = [ - LLAMA_CONFIG, QWEN2_CONFIG, PHI3_CONFIG, GPT2_CONFIG, STABLELM_CONFIG, - STARCODER_CONFIG + LLAMA_CONFIG, + QWEN2_CONFIG, + PHI3_CONFIG, + GPT2_CONFIG, + STABLELM_CONFIG, + STARCODER_CONFIG, ] @@ -91,17 +96,26 @@ def test_models( if num_gpus_available < tp_size: pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + tokenizer = AutoTokenizer.from_pretrained(model.original_model) + if tokenizer.chat_template is not None: + messages = [[{ + 'role': 'user', + 'content': prompt + }] for prompt in example_prompts] + example_prompts = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True) + # Run unquantized model. with vllm_runner(model_name=model.original_model, dtype=dtype, max_model_len=MAX_MODEL_LEN, tensor_parallel_size=tp_size) as original_model: - original_outputs = original_model.generate_greedy_logprobs( example_prompts[:-1], max_tokens, num_logprobs) # Run gguf model. with vllm_runner(model_name=model.gguf_model, + tokenizer_name=model.original_model, dtype=dtype, max_model_len=MAX_MODEL_LEN, tensor_parallel_size=tp_size) as gguf_model: From 3eb08eb06ef52f7586c8be8da37babb212409b63 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 31 Oct 2024 15:25:05 +0800 Subject: [PATCH 07/16] fix gguf test --- tests/models/decoder_only/language/test_gguf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py index 76aa3b2c74a38..81b93ebdf0fc0 100644 --- a/tests/models/decoder_only/language/test_gguf.py +++ b/tests/models/decoder_only/language/test_gguf.py @@ -39,7 +39,7 @@ def gguf_model(self): QWEN2_CONFIG = GGUFTestConfig( original_model="Qwen/Qwen2.5-1.5B-Instruct", gguf_repo="Qwen/Qwen2.5-1.5B-Instruct-GGUF", - gguf_filename="qwen2.5-1.5b-instruct-q4_k_m.gguf", + gguf_filename="qwen2.5-1.5b-instruct-q6_k.gguf", ) PHI3_CONFIG = GGUFTestConfig( @@ -72,7 +72,7 @@ def gguf_model(self): PHI3_CONFIG, GPT2_CONFIG, STABLELM_CONFIG, - STARCODER_CONFIG, + # STARCODER_CONFIG, # broken ] From 06b843d19a089c811f7a6d88166554c952c404ed Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 1 Nov 2024 00:26:35 +0800 Subject: [PATCH 08/16] add transformers flag for gguf test --- tests/models/decoder_only/language/test_gguf.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py index 81b93ebdf0fc0..77a20e83d5275 100644 --- a/tests/models/decoder_only/language/test_gguf.py +++ b/tests/models/decoder_only/language/test_gguf.py @@ -7,6 +7,7 @@ from typing import List, NamedTuple, Type import pytest +import transformers from huggingface_hub import hf_hub_download from transformers import AutoTokenizer @@ -24,12 +25,15 @@ class GGUFTestConfig(NamedTuple): original_model: str gguf_repo: str gguf_filename: str + run_requirement: bool = True @property def gguf_model(self): return hf_hub_download(self.gguf_repo, filename=self.gguf_filename) +TRANSFORMERS_REQUIREMENT = transformers.__version__.startswith("4.46.0") + LLAMA_CONFIG = GGUFTestConfig( original_model="meta-llama/Llama-3.2-1B-Instruct", gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF", @@ -46,24 +50,28 @@ def gguf_model(self): original_model="microsoft/Phi-3.5-mini-instruct", gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF", gguf_filename="Phi-3.5-mini-instruct-IQ4_XS.gguf", + run_requirement=TRANSFORMERS_REQUIREMENT, ) GPT2_CONFIG = GGUFTestConfig( original_model="openai-community/gpt2-large", gguf_repo="QuantFactory/gpt2-large-GGUF", gguf_filename="gpt2-large.Q4_K_M.gguf", + run_requirement=TRANSFORMERS_REQUIREMENT, ) STABLELM_CONFIG = GGUFTestConfig( original_model="stabilityai/stablelm-3b-4e1t", gguf_repo="afrideva/stablelm-3b-4e1t-GGUF", gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf", + run_requirement=TRANSFORMERS_REQUIREMENT, ) STARCODER_CONFIG = GGUFTestConfig( original_model="bigcode/starcoder2-3b", gguf_repo="QuantFactory/starcoder2-3b-GGUF", gguf_filename="starcoder2-3b.Q6_K.gguf", + run_requirement=TRANSFORMERS_REQUIREMENT, ) MODELS = [ @@ -96,6 +104,10 @@ def test_models( if num_gpus_available < tp_size: pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + if not model.run_requirement: + pytest.skip( + f"Model not supported in transformers=={transformers.__version__}") + tokenizer = AutoTokenizer.from_pretrained(model.original_model) if tokenizer.chat_template is not None: messages = [[{ From 37004618aba000337bb64666d91c8cb79e77fc47 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 1 Nov 2024 01:14:11 +0800 Subject: [PATCH 09/16] make transformers flag more robust --- tests/models/decoder_only/language/test_gguf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py index 77a20e83d5275..d50c42b3363f1 100644 --- a/tests/models/decoder_only/language/test_gguf.py +++ b/tests/models/decoder_only/language/test_gguf.py @@ -10,6 +10,7 @@ import transformers from huggingface_hub import hf_hub_download from transformers import AutoTokenizer +from packaging.version import parse from tests.quantization.utils import is_quant_method_supported @@ -32,7 +33,7 @@ def gguf_model(self): return hf_hub_download(self.gguf_repo, filename=self.gguf_filename) -TRANSFORMERS_REQUIREMENT = transformers.__version__.startswith("4.46.0") +TRANSFORMERS_REQUIREMENT = parse(transformers.__version__) >= parse("4.46.0") LLAMA_CONFIG = GGUFTestConfig( original_model="meta-llama/Llama-3.2-1B-Instruct", From f76ea3b7ca5510b20f4b1038a26d0557d09116fd Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 1 Nov 2024 01:17:18 +0800 Subject: [PATCH 10/16] code format --- tests/models/decoder_only/language/test_gguf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py index d50c42b3363f1..8311d60dda0bf 100644 --- a/tests/models/decoder_only/language/test_gguf.py +++ b/tests/models/decoder_only/language/test_gguf.py @@ -9,8 +9,8 @@ import pytest import transformers from huggingface_hub import hf_hub_download -from transformers import AutoTokenizer from packaging.version import parse +from transformers import AutoTokenizer from tests.quantization.utils import is_quant_method_supported From 5a5aa2ce5d178ee464fe6e10a968194a10412338 Mon Sep 17 00:00:00 2001 From: Isotr0py <Isotr0py@outlook.com> Date: Fri, 1 Nov 2024 01:19:02 +0800 Subject: [PATCH 11/16] Update vllm/model_executor/models/gpt2.py Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> --- vllm/model_executor/models/gpt2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 398b0714187e1..59a3f6b65fe26 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -300,7 +300,7 @@ def sample( def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: - if "lm_head" in name: + if name.startswith("lm_head"): # GPT-2 ties the weights of the embedding layer and the final # linear layer. continue From f7c75e689219568e5714a361c9553cfed0fd6e3e Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 1 Nov 2024 15:27:11 +0800 Subject: [PATCH 12/16] stablelm add prefix --- vllm/model_executor/models/stablelm.py | 54 +++++++++++++++++--------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 1ef192e136b7f..c7435338b2e69 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -23,7 +23,7 @@ import torch from torch import nn -from transformers import PretrainedConfig +from transformers import StableLmConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig @@ -50,8 +50,9 @@ class StablelmMLP(nn.Module): def __init__(self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None) -> None: + config: StableLmConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -59,11 +60,13 @@ def __init__(self, self.gate_up_proj = MergedColumnParallelLinear( config.hidden_size, [config.intermediate_size] * 2, bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") self.down_proj = RowParallelLinear(config.intermediate_size, config.hidden_size, bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.down_proj") self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -76,9 +79,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class StablelmAttention(nn.Module): def __init__(self, - config: PretrainedConfig, + config: StableLmConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -116,11 +120,13 @@ def __init__(self, self.total_num_heads, self.total_num_key_value_heads, self.qkv_bias, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj") self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, self.hidden_size, bias=False, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.o_proj") self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.rotary_ndims, @@ -153,13 +159,17 @@ class StablelmDecoderLayer(nn.Module): def __init__( self, - config: PretrainedConfig, + config: StableLmConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() - self.self_attn = StablelmAttention(config, cache_config, quant_config) - self.mlp = StablelmMLP(config, quant_config) + self.self_attn = StablelmAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attn") + self.mlp = StablelmMLP(config, quant_config, prefix=f"{prefix}.mlp") norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) @@ -196,20 +206,21 @@ def forward( class StableLMEpochModel(nn.Module): def __init__(self, - config: PretrainedConfig, + config: StableLmConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = '') -> None: + prefix: str = "") -> None: super().__init__() self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, quant_config=quant_config, + prefix=f"{prefix}.embed_tokens", ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: StablelmDecoderLayer(config, cache_config, - quant_config), + lambda prefix: StablelmDecoderLayer( + config, cache_config, quant_config, prefix=f"{prefix}.layers"), prefix=f"{prefix}.layers", ) norm_eps = getattr(config, "norm_eps", @@ -250,17 +261,22 @@ class StablelmForCausalLM(nn.Module, SupportsPP): def __init__( self, - config: PretrainedConfig, + config: StableLmConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = StableLMEpochModel(config, cache_config, quant_config) + self.model = StableLMEpochModel(config, + cache_config, + quant_config, + prefix=f"{prefix}.model") self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.lm_head") if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) From f4c78cd41eba8e946eaaf0e88470cd1e3a37e085 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Fri, 1 Nov 2024 15:49:39 +0800 Subject: [PATCH 13/16] add prefix for gpt2 and starcoder2 --- vllm/model_executor/models/gpt2.py | 9 ++++-- vllm/model_executor/models/stablelm.py | 2 +- vllm/model_executor/models/starcoder2.py | 38 +++++++++++++++++------- 3 files changed, 34 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 59a3f6b65fe26..f8f80ee86447d 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -201,7 +201,8 @@ def __init__( self.embed_dim = config.hidden_size self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.wte") self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, @@ -249,6 +250,7 @@ def __init__( config: GPT2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -256,13 +258,14 @@ def __init__( self.transformer = GPT2Model(config, cache_config, quant_config, - prefix="transformer") + prefix=f"{prefix}.transformer") if self.config.tie_word_embeddings: self.lm_head = self.transformer.wte else: self.lm_head = ParallelLMHead(self.config.vocab_size, self.config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.lm_head") self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index c7435338b2e69..90b9449b79e79 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -220,7 +220,7 @@ def __init__(self, self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: StablelmDecoderLayer( - config, cache_config, quant_config, prefix=f"{prefix}.layers"), + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers", ) norm_eps = getattr(config, "norm_eps", diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 123bcd0a42f91..5f5f818c84f48 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -52,7 +52,8 @@ class Starcoder2Attention(nn.Module): def __init__(self, config: Starcoder2Config, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.config = config @@ -86,12 +87,14 @@ def __init__(self, self.total_num_kv_heads, bias=self.use_bias, quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=self.use_bias, quant_config=quant_config, + prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( self.head_dim, @@ -126,19 +129,22 @@ class Starcoder2MLP(nn.Module): def __init__(self, config: Starcoder2Config, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.c_fc = ColumnParallelLinear( config.hidden_size, config.intermediate_size, bias=config.use_bias, quant_config=quant_config, + prefix=f"{prefix}.c_fc", ) self.c_proj = RowParallelLinear( config.intermediate_size, config.hidden_size, bias=config.use_bias, quant_config=quant_config, + prefix=f"{prefix}.c_proj", ) self.act = get_act_fn(config.hidden_act, quant_config, config.intermediate_size) @@ -155,13 +161,17 @@ class Starcoder2DecoderLayer(nn.Module): def __init__(self, config: Starcoder2Config, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Starcoder2Attention(config, cache_config, - quant_config=quant_config) - self.mlp = Starcoder2MLP(config, quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + self.mlp = Starcoder2MLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, @@ -208,13 +218,16 @@ def __init__(self, self.vocab_size = config.vocab_size # TODO: consider padding_idx (currently removed) - self.embed_tokens = VocabParallelEmbedding(config.vocab_size, - config.hidden_size, - quant_config=quant_config) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Starcoder2DecoderLayer( - config, cache_config, quant_config=quant_config), + config, cache_config, quant_config=quant_config, prefix=prefix + ), prefix=f"{prefix}.layers", ) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -251,12 +264,14 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): def __init__(self, config: Starcoder2Config, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.config = config self.model = Starcoder2Model(config, cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.model") self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size if config.tie_word_embeddings: @@ -269,6 +284,7 @@ def __init__(self, org_num_embeddings=config.vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, quant_config=quant_config, + prefix=f"{prefix}.lm_head", ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) From f8add6ff944d0735f74ae7cadab03a7225624abc Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Wed, 6 Nov 2024 00:15:31 +0800 Subject: [PATCH 14/16] fix phi-3 tp --- vllm/model_executor/layers/linear.py | 34 ++++++++++++---------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1808dfc53072e..0492a7951a990 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -444,7 +444,10 @@ def weight_loader(self, param.data[loaded_shard_id].copy_(loaded_weight) param.shard_weight_type[loaded_shard_id] = loaded_weight.item() else: - param.weight_type = loaded_weight.item() + param.shard_weight_type = { + i: loaded_weight.item() + for i, _ in enumerate(self.output_sizes) + } return if is_gguf_weight: @@ -455,20 +458,15 @@ def weight_loader(self, shard_size = loaded_weight.size(output_dim) // tp_size start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) - if loaded_shard_id is not None: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) param.shard_id.append(loaded_shard_id) param.shard_id_map[loaded_shard_id] = len(param.data_container) param.data_container.append(loaded_weight) if len(param.data_container) == 2: self.qweight = param.materialize_nested() - else: - param.materialize(loaded_weight.shape, - dtype=loaded_weight.dtype) - param.data.copy_(loaded_weight) - return + return param_data = param.data output_dim = getattr(param, "output_dim", None) @@ -784,12 +782,15 @@ def weight_loader(self, is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False) if is_gguf_weight_type: + idx_map = {"q": 0, "k": 1, "v": 2} if loaded_shard_id is not None: - idx_map = {"q": 0, "k": 1, "v": 2} param.data[idx_map[loaded_shard_id]].copy_(loaded_weight) param.shard_weight_type[loaded_shard_id] = loaded_weight.item() else: - param.weight_type = loaded_weight.item() + param.shard_weight_type = { + k: loaded_weight.item() + for k in idx_map + } return if is_gguf_weight: @@ -800,20 +801,15 @@ def weight_loader(self, shard_size = loaded_weight.size(output_dim) // tp_size start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) - if loaded_shard_id is not None: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) param.shard_id.append(loaded_shard_id) param.shard_id_map[loaded_shard_id] = len(param.data_container) param.data_container.append(loaded_weight) if len(param.data_container) == 3: self.qweight = param.materialize_nested() - else: - param.materialize(loaded_weight.shape, - dtype=loaded_weight.dtype) - param.data.copy_(loaded_weight) - return + return param_data = param.data output_dim = getattr(param, "output_dim", None) From 01f2250c48d8af14aa574cb24e379bf5dfa8cb01 Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 5 Dec 2024 22:45:49 +0800 Subject: [PATCH 15/16] code format Signed-off-by: Isotr0py <2037008807@qq.com> --- tests/models/decoder_only/language/test_gguf.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py index 9ccba994f287f..ea503b6562e07 100644 --- a/tests/models/decoder_only/language/test_gguf.py +++ b/tests/models/decoder_only/language/test_gguf.py @@ -7,9 +7,7 @@ from typing import List, NamedTuple, Type import pytest -import transformers from huggingface_hub import hf_hub_download -from packaging.version import parse from transformers import AutoTokenizer from tests.quantization.utils import is_quant_method_supported @@ -26,7 +24,6 @@ class GGUFTestConfig(NamedTuple): original_model: str gguf_repo: str gguf_filename: str - run_requirement: bool = True @property def gguf_model(self): @@ -49,28 +46,24 @@ def gguf_model(self): original_model="microsoft/Phi-3.5-mini-instruct", gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF", gguf_filename="Phi-3.5-mini-instruct-IQ4_XS.gguf", - run_requirement=TRANSFORMERS_REQUIREMENT, ) GPT2_CONFIG = GGUFTestConfig( original_model="openai-community/gpt2-large", gguf_repo="QuantFactory/gpt2-large-GGUF", gguf_filename="gpt2-large.Q4_K_M.gguf", - run_requirement=TRANSFORMERS_REQUIREMENT, ) STABLELM_CONFIG = GGUFTestConfig( original_model="stabilityai/stablelm-3b-4e1t", gguf_repo="afrideva/stablelm-3b-4e1t-GGUF", gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf", - run_requirement=TRANSFORMERS_REQUIREMENT, ) STARCODER_CONFIG = GGUFTestConfig( original_model="bigcode/starcoder2-3b", gguf_repo="QuantFactory/starcoder2-3b-GGUF", gguf_filename="starcoder2-3b.Q6_K.gguf", - run_requirement=TRANSFORMERS_REQUIREMENT, ) MODELS = [ From 53b21ede30575020ef2902cc334f10a3c4f87c2a Mon Sep 17 00:00:00 2001 From: Isotr0py <2037008807@qq.com> Date: Thu, 5 Dec 2024 22:48:05 +0800 Subject: [PATCH 16/16] fix test Signed-off-by: Isotr0py <2037008807@qq.com> --- tests/models/decoder_only/language/test_gguf.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py index ea503b6562e07..81b93ebdf0fc0 100644 --- a/tests/models/decoder_only/language/test_gguf.py +++ b/tests/models/decoder_only/language/test_gguf.py @@ -78,18 +78,7 @@ def gguf_model(self): @pytest.mark.skipif(not is_quant_method_supported("gguf"), reason="gguf is not supported on this GPU type.") -@pytest.mark.parametrize(("original_model", "gguf_id", "gguf_path"), [ - ("meta-llama/Llama-3.2-1B-Instruct", - "bartowski/Llama-3.2-1B-Instruct-GGUF", - "Llama-3.2-1B-Instruct-Q4_K_M.gguf"), - ("meta-llama/Llama-3.2-1B-Instruct", - "bartowski/Llama-3.2-1B-Instruct-GGUF", - "Llama-3.2-1B-Instruct-IQ4_XS.gguf"), - ("Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct-GGUF", - "qwen2-1_5b-instruct-q4_k_m.gguf"), - ("Qwen/Qwen2-1.5B-Instruct", "legraphista/Qwen2-1.5B-Instruct-IMat-GGUF", - "Qwen2-1.5B-Instruct.IQ4_XS.gguf"), -]) +@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [32]) @pytest.mark.parametrize("num_logprobs", [5])