From 960c605ca0073a5dd1adf856add5ef5a495aaf6e Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 25 Oct 2024 14:12:56 +0800
Subject: [PATCH 01/16] fix phi3 gguf

---
 vllm/model_executor/layers/linear.py | 48 ++++++++++++++++++----------
 vllm/model_executor/models/llama.py  |  3 +-
 2 files changed, 34 insertions(+), 17 deletions(-)

diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 94f30412e43b3..1808dfc53072e 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -440,8 +440,11 @@ def weight_loader(self,
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
         if is_gguf_weight_type:
-            param.data[loaded_shard_id].copy_(loaded_weight)
-            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            if loaded_shard_id is not None:
+                param.data[loaded_shard_id].copy_(loaded_weight)
+                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            else:
+                param.weight_type = loaded_weight.item()
             return
 
         if is_gguf_weight:
@@ -455,11 +458,16 @@ def weight_loader(self,
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
 
-            param.shard_id.append(loaded_shard_id)
-            param.shard_id_map[loaded_shard_id] = len(param.data_container)
-            param.data_container.append(loaded_weight)
-            if len(param.data_container) == 2:
-                self.qweight = param.materialize_nested()
+            if loaded_shard_id is not None:
+                param.shard_id.append(loaded_shard_id)
+                param.shard_id_map[loaded_shard_id] = len(param.data_container)
+                param.data_container.append(loaded_weight)
+                if len(param.data_container) == 2:
+                    self.qweight = param.materialize_nested()
+            else:
+                param.materialize(loaded_weight.shape,
+                                  dtype=loaded_weight.dtype)
+                param.data.copy_(loaded_weight)
             return
 
         param_data = param.data
@@ -775,10 +783,13 @@ def weight_loader(self,
         # initialize GGUF param after we know the quantize type
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
-        if is_gguf_weight_type and loaded_shard_id is not None:
-            idx_map = {"q": 0, "k": 1, "v": 2}
-            param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
-            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+        if is_gguf_weight_type:
+            if loaded_shard_id is not None:
+                idx_map = {"q": 0, "k": 1, "v": 2}
+                param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
+                param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
+            else:
+                param.weight_type = loaded_weight.item()
             return
 
         if is_gguf_weight:
@@ -792,11 +803,16 @@ def weight_loader(self,
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
 
-            param.shard_id.append(loaded_shard_id)
-            param.shard_id_map[loaded_shard_id] = len(param.data_container)
-            param.data_container.append(loaded_weight)
-            if len(param.data_container) == 3:
-                self.qweight = param.materialize_nested()
+            if loaded_shard_id is not None:
+                param.shard_id.append(loaded_shard_id)
+                param.shard_id_map[loaded_shard_id] = len(param.data_container)
+                param.data_container.append(loaded_weight)
+                if len(param.data_container) == 3:
+                    self.qweight = param.materialize_nested()
+            else:
+                param.materialize(loaded_weight.shape,
+                                  dtype=loaded_weight.dtype)
+                param.data.copy_(loaded_weight)
             return
 
         param_data = param.data
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index c346e3e808e3f..08964fa4f706e 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -156,7 +156,8 @@ def __init__(
         )
 
         is_neox_style = True
-        if quant_config is not None and quant_config.get_name() == "gguf":
+        is_gguf = quant_config and quant_config.get_name() == "gguf"
+        if is_gguf and config.model_type == "llama":
             is_neox_style = False
 
         self.rotary_emb = get_rope(

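A minimal, self-contained sketch of the dispatch this patch introduces (the plain dict stands in for vLLM's GGUF parameter object; it is not the real class). Phi-3 checkpoints keep qkv_proj / gate_up_proj fused, so the GGUF loader sees one tensor per merged layer and calls weight_loader with loaded_shard_id=None instead of once per shard:

# Standalone sketch (not vLLM code) of the new loaded_shard_id=None branch.
import torch

def load_gguf_weight_type(param: dict, loaded_weight: torch.Tensor,
                          loaded_shard_id=None) -> None:
    if loaded_shard_id is not None:
        # Unfused checkpoint (e.g. Llama): record the quant type per shard.
        param["shard_weight_type"][loaded_shard_id] = loaded_weight.item()
    else:
        # Fused checkpoint (e.g. Phi-3): one quant type covers the merged weight.
        param["weight_type"] = loaded_weight.item()

param = {"shard_weight_type": {}}
load_gguf_weight_type(param, torch.tensor(12), "q")   # per-shard call
load_gguf_weight_type(param, torch.tensor(12))        # fused call, no shard id
print(param)  # {'shard_weight_type': {'q': 12}, 'weight_type': 12}
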
From 0551d73504b19738a99892740ef00e43f2a20e31 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 25 Oct 2024 14:26:54 +0800
Subject: [PATCH 02/16] update gguf example

---
 examples/gguf_inference.py | 22 ++++++++--------------
 1 file changed, 8 insertions(+), 14 deletions(-)

diff --git a/examples/gguf_inference.py b/examples/gguf_inference.py
index 09a5fcc22e553..aa05c4c0bfaa5 100644
--- a/examples/gguf_inference.py
+++ b/examples/gguf_inference.py
@@ -3,27 +3,20 @@
 from vllm import LLM, SamplingParams
 
 
-def run_gguf_inference(model_path):
-    PROMPT_TEMPLATE = "<|system|>\n{system_message}</s>\n<|user|>\n{prompt}</s>\n<|assistant|>\n"  # noqa: E501
-    system_message = "You are a friendly chatbot who always responds in the style of a pirate."  # noqa: E501
+def run_gguf_inference(model_path, tokenizer):
     # Sample prompts.
     prompts = [
         "How many helicopters can a human eat in one sitting?",
         "What's the future of AI?",
     ]
-    prompts = [
-        PROMPT_TEMPLATE.format(system_message=system_message, prompt=prompt)
-        for prompt in prompts
-    ]
+    prompts = [[{"role": "user", "content": prompt}] for prompt in prompts]
     # Create a sampling params object.
     sampling_params = SamplingParams(temperature=0, max_tokens=128)
 
     # Create an LLM.
-    llm = LLM(model=model_path,
-              tokenizer="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-              gpu_memory_utilization=0.95)
+    llm = LLM(model=model_path, tokenizer=tokenizer)
 
-    outputs = llm.generate(prompts, sampling_params)
+    outputs = llm.chat(prompts, sampling_params)
     # Print the outputs.
     for output in outputs:
         prompt = output.prompt
@@ -32,7 +25,8 @@ def run_gguf_inference(model_path):
 
 
 if __name__ == "__main__":
-    repo_id = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF"
-    filename = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
+    repo_id = "bartowski/Phi-3-medium-4k-instruct-GGUF"
+    filename = "Phi-3-medium-4k-instruct-IQ2_M.gguf"
+    tokenizer = "microsoft/Phi-3-medium-4k-instruct"
     model = hf_hub_download(repo_id, filename=filename)
-    run_gguf_inference(model)
+    run_gguf_inference(model, tokenizer)

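The rewritten example leans on LLM.chat() applying the tokenizer's chat template to the message lists, which is why the hand-rolled PROMPT_TEMPLATE could be dropped. A rough manual equivalent, using the same apply_chat_template call that the test refactor later in this series uses (shown for illustration only):

# Roughly what llm.chat() does to each conversation before generation.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-medium-4k-instruct")
messages = [[{"role": "user", "content": "What's the future of AI?"}]]
prompts = tokenizer.apply_chat_template(messages,
                                        tokenize=False,
                                        add_generation_prompt=True)
print(prompts[0])
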
From 1ed74daf1f6890dea1145e89eb463e2411a9ef9d Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 25 Oct 2024 15:21:37 +0800
Subject: [PATCH 03/16] fix stablelm and starcoder2

---
 vllm/model_executor/models/stablelm.py   | 4 +++-
 vllm/model_executor/models/starcoder2.py | 3 ++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index 083a48588d01a..1ef192e136b7f 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -62,7 +62,8 @@ def __init__(self,
             quant_config=quant_config)
         self.down_proj = RowParallelLinear(config.intermediate_size,
                                            config.hidden_size,
-                                           bias=False)
+                                           bias=False,
+                                           quant_config=quant_config)
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -203,6 +204,7 @@ def __init__(self,
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 8f0644bca3e2e..123bcd0a42f91 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -209,7 +209,8 @@ def __init__(self,
 
         # TODO: consider padding_idx (currently removed)
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.hidden_size)
+                                                   config.hidden_size,
+                                                   quant_config=quant_config)
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: Starcoder2DecoderLayer(

From b3f0e43be9967269a10d1cfdd28a38ead4188eb4 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Sat, 26 Oct 2024 12:07:55 +0800
Subject: [PATCH 04/16] fix gpt2

---
 vllm/model_executor/models/gpt2.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 3330d84021368..8b42918c7cc0d 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -199,7 +199,7 @@ def __init__(
         assert not config.scale_attn_by_inverse_layer_idx
         assert not config.reorder_and_upcast_attn
         self.embed_dim = config.hidden_size
-        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
+        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim, quant_config=quant_config)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
@@ -259,7 +259,8 @@ def __init__(
             self.lm_head = self.transformer.wte
         else:
             self.lm_head = ParallelLMHead(self.config.vocab_size,
-                                          self.config.hidden_size)
+                                          self.config.hidden_size,
+                                          quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
         self.make_empty_intermediate_tensors = (
@@ -297,7 +298,7 @@ def sample(
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
-            if "lm_head.weight" in name:
+            if "lm_head" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
                 continue

From 01dc5c8bc2dfbbdfcdc6f51764ca3db5942af5fc Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Mon, 28 Oct 2024 17:27:36 +0800
Subject: [PATCH 05/16] refactor gguf test

---
 .../models/decoder_only/language/test_gguf.py | 88 ++++++++++++-------
 vllm/model_executor/models/gpt2.py            |  4 +-
 2 files changed, 60 insertions(+), 32 deletions(-)

diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py
index 5dc83942632fd..335b91696f437 100644
--- a/tests/models/decoder_only/language/test_gguf.py
+++ b/tests/models/decoder_only/language/test_gguf.py
@@ -4,33 +4,70 @@
 """
 
 import os
+from typing import List, NamedTuple, Type
 
 import pytest
 from huggingface_hub import hf_hub_download
-from transformers import AutoTokenizer
 
 from tests.quantization.utils import is_quant_method_supported
 
+from ....conftest import VllmRunner
 from ...utils import check_logprobs_close
 
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
 MAX_MODEL_LEN = 1024
 
-# FIXME: Move this to confest
+
+class GGUFTestConfig(NamedTuple):
+    original_model: str
+    gguf_repo: str
+    gguf_filename: str
+
+    @property
+    def gguf_model(self):
+        return hf_hub_download(self.gguf_repo, filename=self.gguf_filename)
+
+
+LLAMA_CONFIG = GGUFTestConfig(
+    original_model="meta-llama/Llama-3.2-1B-Instruct",
+    gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF",
+    gguf_filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf",
+)
+
+QWEN2_CONFIG = GGUFTestConfig(
+    original_model="Qwen/Qwen2-1.5B-Instruct",
+    gguf_repo="Qwen/Qwen2.5-1.5B-Instruct-GGUF",
+    gguf_filename="qwen2.5-1.5b-instruct-q4_k_m.gguf",
+)
+
+PHI3_CONFIG = GGUFTestConfig(
+    original_model="microsoft/Phi-3.5-mini-instruct",
+    gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF",
+    gguf_filename="Phi-3.5-mini-instruct-IQ4_XS.gguf",
+)
+
+GPT2_CONFIG = GGUFTestConfig(
+    original_model="openai-community/gpt2-large",
+    gguf_repo="QuantFactory/gpt2-large-GGUF",
+    gguf_filename="gpt2-large.Q4_K_M.gguf",
+)
+
+STABLELM_CONFIG = GGUFTestConfig(
+    original_model="afrideva/stablelm-3b-4e1t-GGUF",
+    gguf_repo="afrideva/stablelm-3b-4e1t-GGUF",
+    gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf",
+)
+
+STARCODER_CONFIG = GGUFTestConfig(
+    original_model="bigcode/starcoder2-3b",
+    gguf_repo="QuantFactory/starcoder2-3b-GGUF",
+    gguf_filename="starcoder2-3b.Q4_K_M.gguf",
+)
+
 MODELS = [
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
-                     filename="Llama-3.2-1B-Instruct-Q4_K_M.gguf")),
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     hf_hub_download("bartowski/Llama-3.2-1B-Instruct-GGUF",
-                     filename="Llama-3.2-1B-Instruct-IQ4_XS.gguf")),
-    ("Qwen/Qwen2-1.5B-Instruct",
-     hf_hub_download("Qwen/Qwen2-1.5B-Instruct-GGUF",
-                     filename="qwen2-1_5b-instruct-q4_k_m.gguf")),
-    ("Qwen/Qwen2-1.5B-Instruct",
-     hf_hub_download("legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
-                     filename="Qwen2-1.5B-Instruct.IQ4_XS.gguf")),
+    LLAMA_CONFIG, QWEN2_CONFIG, PHI3_CONFIG, GPT2_CONFIG, STABLELM_CONFIG,
+    STARCODER_CONFIG
 ]
 
 
@@ -42,10 +79,10 @@
 @pytest.mark.parametrize("num_logprobs", [5])
 @pytest.mark.parametrize("tp_size", [1, 2])
 def test_models(
-    num_gpus_available,
-    vllm_runner,
-    example_prompts,
-    model,
+    num_gpus_available: int,
+    vllm_runner: Type[VllmRunner],
+    example_prompts: List[str],
+    model: GGUFTestConfig,
     dtype: str,
     max_tokens: int,
     num_logprobs: int,
@@ -54,19 +91,8 @@ def test_models(
     if num_gpus_available < tp_size:
         pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
 
-    original_model, gguf_model = model
-
-    tokenizer = AutoTokenizer.from_pretrained(original_model)
-    messages = [[{
-        'role': 'user',
-        'content': prompt
-    }] for prompt in example_prompts]
-    example_prompts = tokenizer.apply_chat_template(messages,
-                                                    tokenize=False,
-                                                    add_generation_prompt=True)
-
     # Run unquantized model.
-    with vllm_runner(model_name=original_model,
+    with vllm_runner(model_name=model.original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as original_model:
@@ -75,7 +101,7 @@ def test_models(
             example_prompts[:-1], max_tokens, num_logprobs)
 
     # Run gguf model.
-    with vllm_runner(model_name=gguf_model,
+    with vllm_runner(model_name=model.gguf_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as gguf_model:
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 8b42918c7cc0d..398b0714187e1 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -199,7 +199,9 @@ def __init__(
         assert not config.scale_attn_by_inverse_layer_idx
         assert not config.reorder_and_upcast_attn
         self.embed_dim = config.hidden_size
-        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim, quant_config=quant_config)
+        self.wte = VocabParallelEmbedding(config.vocab_size,
+                                          self.embed_dim,
+                                          quant_config=quant_config)
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,

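With the NamedTuple refactor, extending coverage is declarative. A hypothetical addition would look like the following; the repo and file names are placeholders, not part of the patch:

# Hypothetical example only -- the repo and file names below are placeholders.
NEW_MODEL_CONFIG = GGUFTestConfig(
    original_model="some-org/Some-Model-Instruct",
    gguf_repo="some-org/Some-Model-Instruct-GGUF",
    gguf_filename="Some-Model-Instruct.Q4_K_M.gguf",
)
MODELS.append(NEW_MODEL_CONFIG)
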
From e3e858111ee14959558b6158194a1774de25e39c Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Tue, 29 Oct 2024 01:16:24 +0800
Subject: [PATCH 06/16] fix failed tests

---
 .../models/decoder_only/language/test_gguf.py | 26 ++++++++++++++-----
 1 file changed, 20 insertions(+), 6 deletions(-)

diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py
index 335b91696f437..76aa3b2c74a38 100644
--- a/tests/models/decoder_only/language/test_gguf.py
+++ b/tests/models/decoder_only/language/test_gguf.py
@@ -8,6 +8,7 @@
 
 import pytest
 from huggingface_hub import hf_hub_download
+from transformers import AutoTokenizer
 
 from tests.quantization.utils import is_quant_method_supported
 
@@ -36,7 +37,7 @@ def gguf_model(self):
 )
 
 QWEN2_CONFIG = GGUFTestConfig(
-    original_model="Qwen/Qwen2-1.5B-Instruct",
+    original_model="Qwen/Qwen2.5-1.5B-Instruct",
     gguf_repo="Qwen/Qwen2.5-1.5B-Instruct-GGUF",
     gguf_filename="qwen2.5-1.5b-instruct-q4_k_m.gguf",
 )
@@ -54,7 +55,7 @@ def gguf_model(self):
 )
 
 STABLELM_CONFIG = GGUFTestConfig(
-    original_model="afrideva/stablelm-3b-4e1t-GGUF",
+    original_model="stabilityai/stablelm-3b-4e1t",
     gguf_repo="afrideva/stablelm-3b-4e1t-GGUF",
     gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf",
 )
@@ -62,12 +63,16 @@ def gguf_model(self):
 STARCODER_CONFIG = GGUFTestConfig(
     original_model="bigcode/starcoder2-3b",
     gguf_repo="QuantFactory/starcoder2-3b-GGUF",
-    gguf_filename="starcoder2-3b.Q4_K_M.gguf",
+    gguf_filename="starcoder2-3b.Q6_K.gguf",
 )
 
 MODELS = [
-    LLAMA_CONFIG, QWEN2_CONFIG, PHI3_CONFIG, GPT2_CONFIG, STABLELM_CONFIG,
-    STARCODER_CONFIG
+    LLAMA_CONFIG,
+    QWEN2_CONFIG,
+    PHI3_CONFIG,
+    GPT2_CONFIG,
+    STABLELM_CONFIG,
+    STARCODER_CONFIG,
 ]
 
 
@@ -91,17 +96,26 @@ def test_models(
     if num_gpus_available < tp_size:
         pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
 
+    tokenizer = AutoTokenizer.from_pretrained(model.original_model)
+    if tokenizer.chat_template is not None:
+        messages = [[{
+            'role': 'user',
+            'content': prompt
+        }] for prompt in example_prompts]
+        example_prompts = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True)
+
     # Run unquantized model.
     with vllm_runner(model_name=model.original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as original_model:
-
         original_outputs = original_model.generate_greedy_logprobs(
             example_prompts[:-1], max_tokens, num_logprobs)
 
     # Run gguf model.
     with vllm_runner(model_name=model.gguf_model,
+                     tokenizer_name=model.original_model,
                      dtype=dtype,
                      max_model_len=MAX_MODEL_LEN,
                      tensor_parallel_size=tp_size) as gguf_model:

From 3eb08eb06ef52f7586c8be8da37babb212409b63 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Thu, 31 Oct 2024 15:25:05 +0800
Subject: [PATCH 07/16] fix gguf test

---
 tests/models/decoder_only/language/test_gguf.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py
index 76aa3b2c74a38..81b93ebdf0fc0 100644
--- a/tests/models/decoder_only/language/test_gguf.py
+++ b/tests/models/decoder_only/language/test_gguf.py
@@ -39,7 +39,7 @@ def gguf_model(self):
 QWEN2_CONFIG = GGUFTestConfig(
     original_model="Qwen/Qwen2.5-1.5B-Instruct",
     gguf_repo="Qwen/Qwen2.5-1.5B-Instruct-GGUF",
-    gguf_filename="qwen2.5-1.5b-instruct-q4_k_m.gguf",
+    gguf_filename="qwen2.5-1.5b-instruct-q6_k.gguf",
 )
 
 PHI3_CONFIG = GGUFTestConfig(
@@ -72,7 +72,7 @@ def gguf_model(self):
     PHI3_CONFIG,
     GPT2_CONFIG,
     STABLELM_CONFIG,
-    STARCODER_CONFIG,
+    # STARCODER_CONFIG, # broken
 ]
 
 

From 06b843d19a089c811f7a6d88166554c952c404ed Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 1 Nov 2024 00:26:35 +0800
Subject: [PATCH 08/16] add transformers flag for gguf test

---
 tests/models/decoder_only/language/test_gguf.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py
index 81b93ebdf0fc0..77a20e83d5275 100644
--- a/tests/models/decoder_only/language/test_gguf.py
+++ b/tests/models/decoder_only/language/test_gguf.py
@@ -7,6 +7,7 @@
 from typing import List, NamedTuple, Type
 
 import pytest
+import transformers
 from huggingface_hub import hf_hub_download
 from transformers import AutoTokenizer
 
@@ -24,12 +25,15 @@ class GGUFTestConfig(NamedTuple):
     original_model: str
     gguf_repo: str
     gguf_filename: str
+    run_requirement: bool = True
 
     @property
     def gguf_model(self):
         return hf_hub_download(self.gguf_repo, filename=self.gguf_filename)
 
 
+TRANSFORMERS_REQUIREMENT = transformers.__version__.startswith("4.46.0")
+
 LLAMA_CONFIG = GGUFTestConfig(
     original_model="meta-llama/Llama-3.2-1B-Instruct",
     gguf_repo="bartowski/Llama-3.2-1B-Instruct-GGUF",
@@ -46,24 +50,28 @@ def gguf_model(self):
     original_model="microsoft/Phi-3.5-mini-instruct",
     gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF",
     gguf_filename="Phi-3.5-mini-instruct-IQ4_XS.gguf",
+    run_requirement=TRANSFORMERS_REQUIREMENT,
 )
 
 GPT2_CONFIG = GGUFTestConfig(
     original_model="openai-community/gpt2-large",
     gguf_repo="QuantFactory/gpt2-large-GGUF",
     gguf_filename="gpt2-large.Q4_K_M.gguf",
+    run_requirement=TRANSFORMERS_REQUIREMENT,
 )
 
 STABLELM_CONFIG = GGUFTestConfig(
     original_model="stabilityai/stablelm-3b-4e1t",
     gguf_repo="afrideva/stablelm-3b-4e1t-GGUF",
     gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf",
+    run_requirement=TRANSFORMERS_REQUIREMENT,
 )
 
 STARCODER_CONFIG = GGUFTestConfig(
     original_model="bigcode/starcoder2-3b",
     gguf_repo="QuantFactory/starcoder2-3b-GGUF",
     gguf_filename="starcoder2-3b.Q6_K.gguf",
+    run_requirement=TRANSFORMERS_REQUIREMENT,
 )
 
 MODELS = [
@@ -96,6 +104,10 @@ def test_models(
     if num_gpus_available < tp_size:
         pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
 
+    if not model.run_requirement:
+        pytest.skip(
+            f"Model not supported in transformers=={transformers.__version__}")
+
     tokenizer = AutoTokenizer.from_pretrained(model.original_model)
     if tokenizer.chat_template is not None:
         messages = [[{

From 37004618aba000337bb64666d91c8cb79e77fc47 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 1 Nov 2024 01:14:11 +0800
Subject: [PATCH 09/16] make transformers flag more robust

---
 tests/models/decoder_only/language/test_gguf.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py
index 77a20e83d5275..d50c42b3363f1 100644
--- a/tests/models/decoder_only/language/test_gguf.py
+++ b/tests/models/decoder_only/language/test_gguf.py
@@ -10,6 +10,7 @@
 import transformers
 from huggingface_hub import hf_hub_download
 from transformers import AutoTokenizer
+from packaging.version import parse
 
 from tests.quantization.utils import is_quant_method_supported
 
@@ -32,7 +33,7 @@ def gguf_model(self):
         return hf_hub_download(self.gguf_repo, filename=self.gguf_filename)
 
 
-TRANSFORMERS_REQUIREMENT = transformers.__version__.startswith("4.46.0")
+TRANSFORMERS_REQUIREMENT = parse(transformers.__version__) >= parse("4.46.0")
 
 LLAMA_CONFIG = GGUFTestConfig(
     original_model="meta-llama/Llama-3.2-1B-Instruct",

From f76ea3b7ca5510b20f4b1038a26d0557d09116fd Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 1 Nov 2024 01:17:18 +0800
Subject: [PATCH 10/16] code format

---
 tests/models/decoder_only/language/test_gguf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py
index d50c42b3363f1..8311d60dda0bf 100644
--- a/tests/models/decoder_only/language/test_gguf.py
+++ b/tests/models/decoder_only/language/test_gguf.py
@@ -9,8 +9,8 @@
 import pytest
 import transformers
 from huggingface_hub import hf_hub_download
-from transformers import AutoTokenizer
 from packaging.version import parse
+from transformers import AutoTokenizer
 
 from tests.quantization.utils import is_quant_method_supported
 

From 5a5aa2ce5d178ee464fe6e10a968194a10412338 Mon Sep 17 00:00:00 2001
From: Isotr0py <Isotr0py@outlook.com>
Date: Fri, 1 Nov 2024 01:19:02 +0800
Subject: [PATCH 11/16] Update vllm/model_executor/models/gpt2.py

Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
---
 vllm/model_executor/models/gpt2.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 398b0714187e1..59a3f6b65fe26 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -300,7 +300,7 @@ def sample(
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
-            if "lm_head" in name:
+            if name.startswith("lm_head"):
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
                 continue

From f7c75e689219568e5714a361c9553cfed0fd6e3e Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 1 Nov 2024 15:27:11 +0800
Subject: [PATCH 12/16] stablelm add prefix

---
 vllm/model_executor/models/stablelm.py | 54 +++++++++++++++++---------
 1 file changed, 35 insertions(+), 19 deletions(-)

diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index 1ef192e136b7f..c7435338b2e69 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -23,7 +23,7 @@
 
 import torch
 from torch import nn
-from transformers import PretrainedConfig
+from transformers import StableLmConfig
 
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig
@@ -50,8 +50,9 @@
 class StablelmMLP(nn.Module):
 
     def __init__(self,
-                 config: PretrainedConfig,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 config: StableLmConfig,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -59,11 +60,13 @@ def __init__(self,
         self.gate_up_proj = MergedColumnParallelLinear(
             config.hidden_size, [config.intermediate_size] * 2,
             bias=False,
-            quant_config=quant_config)
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj")
         self.down_proj = RowParallelLinear(config.intermediate_size,
                                            config.hidden_size,
                                            bias=False,
-                                           quant_config=quant_config)
+                                           quant_config=quant_config,
+                                           prefix=f"{prefix}.down_proj")
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -76,9 +79,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 class StablelmAttention(nn.Module):
 
     def __init__(self,
-                 config: PretrainedConfig,
+                 config: StableLmConfig,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
@@ -116,11 +120,13 @@ def __init__(self,
                                           self.total_num_heads,
                                           self.total_num_key_value_heads,
                                           self.qkv_bias,
-                                          quant_config=quant_config)
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.qkv_proj")
         self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                         self.hidden_size,
                                         bias=False,
-                                        quant_config=quant_config)
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.rotary_ndims,
@@ -153,13 +159,17 @@ class StablelmDecoderLayer(nn.Module):
 
     def __init__(
         self,
-        config: PretrainedConfig,
+        config: StableLmConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
-        self.self_attn = StablelmAttention(config, cache_config, quant_config)
-        self.mlp = StablelmMLP(config, quant_config)
+        self.self_attn = StablelmAttention(config,
+                                           cache_config,
+                                           quant_config,
+                                           prefix=f"{prefix}.self_attn")
+        self.mlp = StablelmMLP(config, quant_config, prefix=f"{prefix}.mlp")
         norm_eps = getattr(config, "norm_eps",
                            getattr(config, "layer_norm_eps", 1e-05))
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
@@ -196,20 +206,21 @@ def forward(
 class StableLMEpochModel(nn.Module):
 
     def __init__(self,
-                 config: PretrainedConfig,
+                 config: StableLmConfig,
                  cache_config: Optional[CacheConfig] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = '') -> None:
+                 prefix: str = "") -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
             quant_config=quant_config,
+            prefix=f"{prefix}.embed_tokens",
         )
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
-            lambda prefix: StablelmDecoderLayer(config, cache_config,
-                                                quant_config),
+            lambda prefix: StablelmDecoderLayer(
+                config, cache_config, quant_config, prefix=f"{prefix}.layers"),
             prefix=f"{prefix}.layers",
         )
         norm_eps = getattr(config, "norm_eps",
@@ -250,17 +261,22 @@ class StablelmForCausalLM(nn.Module, SupportsPP):
 
     def __init__(
         self,
-        config: PretrainedConfig,
+        config: StableLmConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = StableLMEpochModel(config, cache_config, quant_config)
+        self.model = StableLMEpochModel(config,
+                                        cache_config,
+                                        quant_config,
+                                        prefix=f"{prefix}.model")
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
-                                      quant_config=quant_config)
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.lm_head")
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config.vocab_size)

From f4c78cd41eba8e946eaaf0e88470cd1e3a37e085 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Fri, 1 Nov 2024 15:49:39 +0800
Subject: [PATCH 13/16] add prefix for gpt2 and starcoder2

---
 vllm/model_executor/models/gpt2.py       |  9 ++++--
 vllm/model_executor/models/stablelm.py   |  2 +-
 vllm/model_executor/models/starcoder2.py | 38 +++++++++++++++++-------
 3 files changed, 34 insertions(+), 15 deletions(-)

diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index 59a3f6b65fe26..f8f80ee86447d 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -201,7 +201,8 @@ def __init__(
         self.embed_dim = config.hidden_size
         self.wte = VocabParallelEmbedding(config.vocab_size,
                                           self.embed_dim,
-                                          quant_config=quant_config)
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.wte")
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.start_layer, self.end_layer, self.h = make_layers(
             config.num_hidden_layers,
@@ -249,6 +250,7 @@ def __init__(
         config: GPT2Config,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -256,13 +258,14 @@ def __init__(
         self.transformer = GPT2Model(config,
                                      cache_config,
                                      quant_config,
-                                     prefix="transformer")
+                                     prefix=f"{prefix}.transformer")
         if self.config.tie_word_embeddings:
             self.lm_head = self.transformer.wte
         else:
             self.lm_head = ParallelLMHead(self.config.vocab_size,
                                           self.config.hidden_size,
-                                          quant_config=quant_config)
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.lm_head")
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
         self.make_empty_intermediate_tensors = (
diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index c7435338b2e69..90b9449b79e79 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -220,7 +220,7 @@ def __init__(self,
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: StablelmDecoderLayer(
-                config, cache_config, quant_config, prefix=f"{prefix}.layers"),
+                config, cache_config, quant_config, prefix=prefix),
             prefix=f"{prefix}.layers",
         )
         norm_eps = getattr(config, "norm_eps",
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 123bcd0a42f91..5f5f818c84f48 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -52,7 +52,8 @@ class Starcoder2Attention(nn.Module):
     def __init__(self,
                  config: Starcoder2Config,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
         self.config = config
 
@@ -86,12 +87,14 @@ def __init__(self,
             self.total_num_kv_heads,
             bias=self.use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=self.use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -126,19 +129,22 @@ class Starcoder2MLP(nn.Module):
 
     def __init__(self,
                  config: Starcoder2Config,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
         self.c_fc = ColumnParallelLinear(
             config.hidden_size,
             config.intermediate_size,
             bias=config.use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_fc",
         )
         self.c_proj = RowParallelLinear(
             config.intermediate_size,
             config.hidden_size,
             bias=config.use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.c_proj",
         )
         self.act = get_act_fn(config.hidden_act, quant_config,
                               config.intermediate_size)
@@ -155,13 +161,17 @@ class Starcoder2DecoderLayer(nn.Module):
     def __init__(self,
                  config: Starcoder2Config,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = Starcoder2Attention(config,
                                              cache_config,
-                                             quant_config=quant_config)
-        self.mlp = Starcoder2MLP(config, quant_config=quant_config)
+                                             quant_config=quant_config,
+                                             prefix=f"{prefix}.self_attn")
+        self.mlp = Starcoder2MLP(config,
+                                 quant_config=quant_config,
+                                 prefix=f"{prefix}.mlp")
         self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                             eps=config.norm_epsilon)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
@@ -208,13 +218,16 @@ def __init__(self,
         self.vocab_size = config.vocab_size
 
         # TODO: consider padding_idx (currently removed)
-        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.hidden_size,
-                                                   quant_config=quant_config)
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=f"{prefix}.embed_tokens")
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: Starcoder2DecoderLayer(
-                config, cache_config, quant_config=quant_config),
+                config, cache_config, quant_config=quant_config, prefix=prefix
+            ),
             prefix=f"{prefix}.layers",
         )
         self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon)
@@ -251,12 +264,14 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP):
     def __init__(self,
                  config: Starcoder2Config,
                  cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None):
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
         self.config = config
         self.model = Starcoder2Model(config,
                                      cache_config,
-                                     quant_config=quant_config)
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.model")
         self.vocab_size = config.vocab_size
         self.unpadded_vocab_size = config.vocab_size
         if config.tie_word_embeddings:
@@ -269,6 +284,7 @@ def __init__(self,
                 org_num_embeddings=config.vocab_size,
                 padding_size=DEFAULT_VOCAB_PADDING_SIZE,
                 quant_config=quant_config,
+                prefix=f"{prefix}.lm_head",
             )
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)

From f8add6ff944d0735f74ae7cadab03a7225624abc Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Wed, 6 Nov 2024 00:15:31 +0800
Subject: [PATCH 14/16] fix phi-3 tp

---
 vllm/model_executor/layers/linear.py | 34 ++++++++++++----------------
 1 file changed, 15 insertions(+), 19 deletions(-)

diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 1808dfc53072e..0492a7951a990 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -444,7 +444,10 @@ def weight_loader(self,
                 param.data[loaded_shard_id].copy_(loaded_weight)
                 param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             else:
-                param.weight_type = loaded_weight.item()
+                param.shard_weight_type = {
+                    i: loaded_weight.item()
+                    for i, _ in enumerate(self.output_sizes)
+                }
             return
 
         if is_gguf_weight:
@@ -455,20 +458,15 @@ def weight_loader(self,
             shard_size = loaded_weight.size(output_dim) // tp_size
             start_idx = tp_rank * shard_size
 
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
-
             if loaded_shard_id is not None:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
                 param.shard_id.append(loaded_shard_id)
                 param.shard_id_map[loaded_shard_id] = len(param.data_container)
                 param.data_container.append(loaded_weight)
                 if len(param.data_container) == 2:
                     self.qweight = param.materialize_nested()
-            else:
-                param.materialize(loaded_weight.shape,
-                                  dtype=loaded_weight.dtype)
-                param.data.copy_(loaded_weight)
-            return
+                return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
@@ -784,12 +782,15 @@ def weight_loader(self,
         is_gguf_weight = getattr(param, "is_gguf_weight", False)
         is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
         if is_gguf_weight_type:
+            idx_map = {"q": 0, "k": 1, "v": 2}
             if loaded_shard_id is not None:
-                idx_map = {"q": 0, "k": 1, "v": 2}
                 param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
                 param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             else:
-                param.weight_type = loaded_weight.item()
+                param.shard_weight_type = {
+                    k: loaded_weight.item()
+                    for k in idx_map
+                }
             return
 
         if is_gguf_weight:
@@ -800,20 +801,15 @@ def weight_loader(self,
             shard_size = loaded_weight.size(output_dim) // tp_size
             start_idx = tp_rank * shard_size
 
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
-                                                 shard_size)
-
             if loaded_shard_id is not None:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx,
+                                                     shard_size)
                 param.shard_id.append(loaded_shard_id)
                 param.shard_id_map[loaded_shard_id] = len(param.data_container)
                 param.data_container.append(loaded_weight)
                 if len(param.data_container) == 3:
                     self.qweight = param.materialize_nested()
-            else:
-                param.materialize(loaded_weight.shape,
-                                  dtype=loaded_weight.dtype)
-                param.data.copy_(loaded_weight)
-            return
+                return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)

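This patch keeps the per-rank narrow() only on the shard-id path and lets the fused, shard-id-less GGUF weight fall through to the generic loading code below the early return, recording the weight type for every shard rather than a single param.weight_type. A toy illustration of the per-rank slice the sharded path still applies (shapes are made up):

# Toy illustration (made-up shapes) of the per-rank slice on the shard-id path:
# each tensor-parallel rank keeps a contiguous block of the output dimension
# before the shards are stacked into the GGUF data container.
import torch

def tp_narrow(weight: torch.Tensor, tp_rank: int, tp_size: int,
              output_dim: int = 0) -> torch.Tensor:
    shard_size = weight.size(output_dim) // tp_size
    return weight.narrow(output_dim, tp_rank * shard_size, shard_size)

w = torch.arange(12.0).reshape(6, 2)       # pretend this is one q/k/v shard
print(tp_narrow(w, tp_rank=1, tp_size=2))  # rows 3..5 go to rank 1
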
From 01f2250c48d8af14aa574cb24e379bf5dfa8cb01 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Thu, 5 Dec 2024 22:45:49 +0800
Subject: [PATCH 15/16] code format

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 tests/models/decoder_only/language/test_gguf.py | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py
index 9ccba994f287f..ea503b6562e07 100644
--- a/tests/models/decoder_only/language/test_gguf.py
+++ b/tests/models/decoder_only/language/test_gguf.py
@@ -7,9 +7,7 @@
 from typing import List, NamedTuple, Type
 
 import pytest
-import transformers
 from huggingface_hub import hf_hub_download
-from packaging.version import parse
 from transformers import AutoTokenizer
 
 from tests.quantization.utils import is_quant_method_supported
@@ -26,7 +24,6 @@ class GGUFTestConfig(NamedTuple):
     original_model: str
     gguf_repo: str
     gguf_filename: str
-    run_requirement: bool = True
 
     @property
     def gguf_model(self):
@@ -49,28 +46,24 @@ def gguf_model(self):
     original_model="microsoft/Phi-3.5-mini-instruct",
     gguf_repo="bartowski/Phi-3.5-mini-instruct-GGUF",
     gguf_filename="Phi-3.5-mini-instruct-IQ4_XS.gguf",
-    run_requirement=TRANSFORMERS_REQUIREMENT,
 )
 
 GPT2_CONFIG = GGUFTestConfig(
     original_model="openai-community/gpt2-large",
     gguf_repo="QuantFactory/gpt2-large-GGUF",
     gguf_filename="gpt2-large.Q4_K_M.gguf",
-    run_requirement=TRANSFORMERS_REQUIREMENT,
 )
 
 STABLELM_CONFIG = GGUFTestConfig(
     original_model="stabilityai/stablelm-3b-4e1t",
     gguf_repo="afrideva/stablelm-3b-4e1t-GGUF",
     gguf_filename="stablelm-3b-4e1t.q4_k_m.gguf",
-    run_requirement=TRANSFORMERS_REQUIREMENT,
 )
 
 STARCODER_CONFIG = GGUFTestConfig(
     original_model="bigcode/starcoder2-3b",
     gguf_repo="QuantFactory/starcoder2-3b-GGUF",
     gguf_filename="starcoder2-3b.Q6_K.gguf",
-    run_requirement=TRANSFORMERS_REQUIREMENT,
 )
 
 MODELS = [

From 53b21ede30575020ef2902cc334f10a3c4f87c2a Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Thu, 5 Dec 2024 22:48:05 +0800
Subject: [PATCH 16/16] fix test

Signed-off-by: Isotr0py <2037008807@qq.com>
---
 tests/models/decoder_only/language/test_gguf.py | 13 +------------
 1 file changed, 1 insertion(+), 12 deletions(-)

diff --git a/tests/models/decoder_only/language/test_gguf.py b/tests/models/decoder_only/language/test_gguf.py
index ea503b6562e07..81b93ebdf0fc0 100644
--- a/tests/models/decoder_only/language/test_gguf.py
+++ b/tests/models/decoder_only/language/test_gguf.py
@@ -78,18 +78,7 @@ def gguf_model(self):
 
 @pytest.mark.skipif(not is_quant_method_supported("gguf"),
                     reason="gguf is not supported on this GPU type.")
-@pytest.mark.parametrize(("original_model", "gguf_id", "gguf_path"), [
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     "bartowski/Llama-3.2-1B-Instruct-GGUF",
-     "Llama-3.2-1B-Instruct-Q4_K_M.gguf"),
-    ("meta-llama/Llama-3.2-1B-Instruct",
-     "bartowski/Llama-3.2-1B-Instruct-GGUF",
-     "Llama-3.2-1B-Instruct-IQ4_XS.gguf"),
-    ("Qwen/Qwen2-1.5B-Instruct", "Qwen/Qwen2-1.5B-Instruct-GGUF",
-     "qwen2-1_5b-instruct-q4_k_m.gguf"),
-    ("Qwen/Qwen2-1.5B-Instruct", "legraphista/Qwen2-1.5B-Instruct-IMat-GGUF",
-     "Qwen2-1.5B-Instruct.IQ4_XS.gguf"),
-])
+@pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [5])