Enable Mllama with pagedattention #273

Draft · wants to merge 3 commits into base: habana-main
7 changes: 7 additions & 0 deletions Dockerfile
@@ -71,6 +71,13 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile

# Build vllm-hpu-extension
RUN git clone https://github.com/HabanaAI/vllm-hpu-extension.git && \
cd vllm-hpu-extension && \
git apply ../server/0001-Remove-vllm.patch && \
pip install . --no-cache-dir

RUN cd server && \
make gen-server && \
pip install --no-deps -r requirements.txt && \
39 changes: 39 additions & 0 deletions server/0001-Remove-vllm.patch
@@ -0,0 +1,39 @@
From 77e9978bd9efe883ac58badf77ea8ce9ce79eb69 Mon Sep 17 00:00:00 2001
From: yuanwu <[email protected]>
Date: Thu, 20 Feb 2025 02:17:46 +0000
Subject: [PATCH] Remove vllm

Signed-off-by: yuanwu <[email protected]>
---
vllm_hpu_extension/ops.py | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/vllm_hpu_extension/ops.py b/vllm_hpu_extension/ops.py
index 7658602..60997eb 100644
--- a/vllm_hpu_extension/ops.py
+++ b/vllm_hpu_extension/ops.py
@@ -14,15 +14,15 @@ import habana_frameworks.torch.core as htcore
import habana_frameworks.torch.utils.experimental as htexp
from vllm_hpu_extension.flags import enabled_flags

-from vllm.logger import init_logger
-from vllm.platforms import current_platform
+#from vllm.logger import init_logger
+#from vllm.platforms import current_platform

-logger = init_logger(__name__)
+#logger = init_logger(__name__)


-def is_hpu_gaudi2():
- return current_platform.is_hpu() and htexp._get_device_type(
- ) == htexp.synDeviceType.synDeviceGaudi2
+#def is_hpu_gaudi2():
+# return current_platform.is_hpu() and htexp._get_device_type(
+# ) == htexp.synDeviceType.synDeviceGaudi2


def get_hpu_gaudi2_scale_factor():
--
2.34.1

4 changes: 2 additions & 2 deletions server/text_generation_server/cli.py
@@ -63,7 +63,7 @@ def serve(
), "MASTER_PORT must be set when sharded is True"

# Remove default handler
- logger.remove()
+ #logger.remove()
logger.add(
sys.stdout,
format="{message}",
@@ -193,7 +193,7 @@ def download_weights(
merge_lora: bool = False,
):
# Remove default handler
- logger.remove()
+ #logger.remove()
logger.add(
sys.stdout,
format="{message}",
8 changes: 8 additions & 0 deletions server/text_generation_server/layers/attention/__init__.py
@@ -29,6 +29,14 @@
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
)
elif SYSTEM == "hpu":
from .hpu import (
attention,
paged_attention,
reshape_and_cache,
PREFILL_IN_KV_CACHE,
SUPPORTS_WINDOWING,
)
else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

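Downstream layers do not import the HPU module directly; they go through this package interface, so the new backend is picked up transparently whenever `SYSTEM == "hpu"`. A minimal sketch of a caller (illustrative only, not part of this PR):

```python
# Illustrative caller, assuming a Gaudi environment where SYSTEM resolves to "hpu".
# The package-level import below then re-exports the functions from layers/attention/hpu.py.
from text_generation_server.layers.attention import (
    attention,
    paged_attention,
    reshape_and_cache,
    SUPPORTS_WINDOWING,
    PREFILL_IN_KV_CACHE,
)

# The HPU backend advertises no sliding-window support and no prefill-in-KV-cache.
assert SUPPORTS_WINDOWING is False and PREFILL_IN_KV_CACHE is False
```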
136 changes: 136 additions & 0 deletions server/text_generation_server/layers/attention/hpu.py
@@ -0,0 +1,136 @@
import torch
import itertools
from text_generation_server.models.globals import (
ATTENTION,
BLOCK_SIZE,
)
from text_generation_server.layers.attention import Seqlen
from typing import Optional, List
from vllm_hpu_extension import cache_ops, ops
from vllm_hpu_extension.utils import (Matmul, ModuleFusedSDPA, Softmax,
VLLMKVCache)
try:
from habana_frameworks.torch.hpex.kernels import FusedSDPA
except ImportError:
print("Not using HPU fused scaled dot-product attention kernel.")
FusedSDPA = None

SUPPORTS_WINDOWING = False
PREFILL_IN_KV_CACHE = False
def fetch_from_cache(cache, blocks):
return cache.index_select(0, blocks)

def attention(
q: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
seqlen: Seqlen,
block_tables: torch.Tensor,
softmax_scale,
window_size_left=-1,
causal=True,
softcap: Optional[float] = None,
):
attn_output = FusedSDPA.apply(
q, key_cache, value_cache, None, 0.0, causal, None
)

return attn_output


def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slots: torch.Tensor,
):
if ATTENTION in {"flashdecoding", "flashinfer"}:
shape = key_cache.shape
key_cache.view(-1, shape[-2], shape[-1])[slots] = key
value_cache.view(-1, shape[-2], shape[-1])[slots] = value
else:
cache_ops.reshape_and_cache(
key, value, key_cache, value_cache, slots, "auto", 1.0
)

def paged_attention(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
kv_head_mapping: torch.Tensor,
softmax_scale: float,
block_tables: torch.Tensor,
seqlen: Seqlen,
max_s: int,
softcap: Optional[float] = None,
):

batch_size, seq_len, hidden_size = query.shape
blocks_used = [len(bt) for bt in block_tables if bt]
block_list = []
block_scales = []
for i, bt in enumerate(block_tables):
block_list.extend(bt)
blocks_in_group = len(bt)
if blocks_in_group > 0:
scale = 1.0 / blocks_in_group
block_scales.extend([scale] * blocks_in_group)

block_mapping_nested: List[List[int]] = [
[i] * b_u for i, b_u in enumerate(blocks_used)
]
block_mapping: List[int] = list(
itertools.chain.from_iterable(block_mapping_nested))
block_list = torch.tensor(block_list,
dtype=torch.int,
device="hpu")
block_mapping = torch.tensor(block_mapping,
dtype=torch.long,
device="hpu")
block_scales = torch.tensor(block_scales,
dtype=torch.bfloat16,
device="hpu")
output = ops.flat_pa(
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_list=block_list,
block_mapping=block_mapping,
block_bias=None,
block_scales=block_scales,
block_groups=None,
scale=softmax_scale,
matmul_qk_op=Matmul(),
matmul_av_op=Matmul(),
batch2block_matmul_op=Matmul(),
block2batch_matmul_op=Matmul(),
keys_fetch_func=fetch_from_cache,
values_fetch_func=fetch_from_cache)
# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)

# out = torch.empty_like(query)
# ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
# out,
# query,
# key_cache,
# value_cache,
# kv_head_mapping,
# softmax_scale,
# block_tables,
# seqlen.input_lengths,
# BLOCK_SIZE,
# max_s,
# None,
# )
# return out


__all__ = [
"PREFILL_IN_KV_CACHE",
"SUPPORTS_WINDOWING",
"attention",
"paged_attention",
"reshape_and_cache",
]
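To make the block bookkeeping in `paged_attention` above concrete, the snippet below rebuilds `block_list`, `block_mapping`, and `block_scales` for a toy batch on CPU. It only mirrors the list construction from this file; `ops.flat_pa` and an actual HPU device are not needed to run it.

```python
import itertools
from typing import List

import torch

# Toy paged-KV layout: sequence 0 owns cache blocks [0, 1, 2], sequence 1 owns block [5].
block_tables: List[List[int]] = [[0, 1, 2], [5]]

block_list: List[int] = []      # flattened block ids in batch order
block_scales: List[float] = []  # 1/len(table) per block, consumed by flat_pa
for bt in block_tables:
    block_list.extend(bt)
    if bt:
        block_scales.extend([1.0 / len(bt)] * len(bt))

# block_mapping[i] records which sequence block_list[i] belongs to.
blocks_used = [len(bt) for bt in block_tables if bt]
block_mapping: List[int] = list(
    itertools.chain.from_iterable([i] * n for i, n in enumerate(blocks_used))
)

print(torch.tensor(block_list, dtype=torch.int))         # [0, 1, 2, 5]
print(torch.tensor(block_mapping, dtype=torch.long))     # [0, 0, 0, 1]
print(torch.tensor(block_scales, dtype=torch.bfloat16))  # roughly [1/3, 1/3, 1/3, 1.0]
```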
23 changes: 22 additions & 1 deletion server/text_generation_server/layers/layernorm.py
@@ -97,6 +97,16 @@ def forward(self, hidden_states, residual=None):
)
return out, residual if residual is not None else hidden_states

elif SYSTEM == "hpu":

class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if residual is not None:
hidden_states += residual
residual = hidden_states

return super().forward(hidden_states), residual


class FastRMSNorm(nn.Module):
def __init__(self, weight: torch.Tensor, eps: float):
@@ -111,7 +121,18 @@ def load(cls, prefix, weights, eps=1e-6):
return cls(weight, eps)

def forward(self, hidden_states, residual=None):
if SYSTEM == "ipex":
if SYSTEM == "hpu":
from habana_frameworks.torch.hpex.normalization import FusedRMSNorm
# mixed dtypes are not good for FusedRMSNorm, both inputs need to have same dtype
if hidden_states.dtype != self.weight.dtype:
orig_dtype = hidden_states.dtype
out = FusedRMSNorm.apply(hidden_states.to(self.weight.dtype), self.weight, self.variance_epsilon)
return out.to(orig_dtype), residual if residual is not None else hidden_states
else:
out = FusedRMSNorm.apply(hidden_states, self.weight, self.variance_epsilon)
return out, residual if residual is not None else hidden_states

elif SYSTEM == "ipex":
out = ipex.llm.functional.add_rms_norm(
residual,
hidden_states,
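For readers without Gaudi hardware, here is a minimal pure-PyTorch reference of what the HPU `FastRMSNorm` branch is assumed to compute: RMSNorm evaluated in the weight's dtype, with the output cast back to the activation's original dtype when the two differ. The fused `FusedRMSNorm` kernel is what actually runs on HPU; this is only a behavioral sketch.

```python
import torch

def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    """Behavioral sketch of the HPU FastRMSNorm.forward path above."""
    orig_dtype = hidden_states.dtype
    # Avoid mixed dtypes, as the HPU branch does, by computing in the weight's dtype.
    x = hidden_states.to(weight.dtype)
    variance = x.pow(2).mean(-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps) * weight
    return out.to(orig_dtype)

# Mixed-dtype example: bfloat16 activations with float32 weights.
h = torch.randn(2, 4, dtype=torch.bfloat16)
w = torch.ones(4, dtype=torch.float32)
print(rms_norm_reference(h, w, eps=1e-6).dtype)  # torch.bfloat16
```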
2 changes: 1 addition & 1 deletion server/text_generation_server/layers/moe/__init__.py
@@ -27,7 +27,7 @@
if SYSTEM == "rocm":
from .fused_moe_rocm import grouped_topk
from vllm.model_executor.layers.fused_moe import fused_topk
elif SYSTEM != "ipex":
elif SYSTEM != "ipex" and SYSTEM != "hpu":
from moe_kernels.fused_moe import fused_topk, grouped_topk


2 changes: 1 addition & 1 deletion server/text_generation_server/layers/moe/unquantized.py
@@ -8,7 +8,7 @@

if SYSTEM == "rocm":
from vllm.model_executor.layers.fused_moe import fused_moe
elif SYSTEM != "ipex":
elif SYSTEM != "ipex" and SYSTEM != "hpu":
from moe_kernels.fused_moe import fused_moe


5 changes: 5 additions & 0 deletions server/text_generation_server/layers/rotary.py
@@ -72,6 +72,11 @@ def forward(
ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), True
)
elif SYSTEM == "hpu":
from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingHelperV2 as FusedRoPE

query = FusedRoPE.apply(query, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), None)
key = FusedRoPE.apply(key, cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0), None)
else:
raise ValueError(
"Your system seem to be not supported. Please check your install or open an issue at https://github.com/huggingface/text-generation-inference/issues with a clear reproduction."
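As a point of comparison, the snippet below is a plain-PyTorch rotary application in the rotate-half (GPT-NeoX style) convention. It is an assumption that `RotaryPosEmbeddingHelperV2` applies this convention with the broadcast `cos`/`sin` shown above; the Habana kernel itself is not inspected here.

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) over the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate-half RoPE; cos/sin broadcast over the leading batch/head dims,
    matching the cos.unsqueeze(0).unsqueeze(0) in the HPU branch above."""
    return x * cos + rotate_half(x) * sin

# Build cos/sin tables for a toy [batch, heads, seq, head_dim] layout.
seq_len, head_dim = 8, 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
freqs = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().unsqueeze(0).unsqueeze(0)  # [1, 1, seq, head_dim]
sin = emb.sin().unsqueeze(0).unsqueeze(0)

q = torch.randn(1, 2, seq_len, head_dim)
print(apply_rope_reference(q, cos, sin).shape)  # torch.Size([1, 2, 8, 64])
```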
43 changes: 36 additions & 7 deletions server/text_generation_server/models/__init__.py
@@ -1,5 +1,6 @@
import torch
import os
import enum

from loguru import logger
from transformers.configuration_utils import PretrainedConfig
@@ -16,21 +17,22 @@
from text_generation_server.models.causal_lm import CausalLM
from text_generation_server.models.bloom import BLOOM
from text_generation_server.models.starcoder import StarCoder
- from text_generation_server.models.vlm_causal_lm import VlmCausalLM
- #from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
+ from text_generation_server.models.static_vlm_causal_lm import StaticVlmCausalLM
+ from text_generation_server.models.mllama_causal_lm import MllamaCausalLM
from text_generation_server.models.custom_modeling.llava_next import (
LlavaNextForConditionalGeneration,
)
- # from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
- # from text_generation_server.models.custom_modeling.mllama import (
- #     MllamaForConditionalGeneration,
- # )
+ from text_generation_server.models.mllama_causal_lm import MllamaCausalLMBatch
+ from text_generation_server.models.custom_modeling.mllama import (
+     MllamaForConditionalGeneration,
+ )
from text_generation_server.utils.adapter import (
AdapterParameters,
build_layer_weight_lookup,
load_and_merge_adapters,
AdapterInfo,
)
from text_generation_server.adapters.lora import LoraWeights


from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
@@ -39,6 +41,19 @@
# Disable gradients
torch.set_grad_enabled(False)

class ModelType(enum.Enum):
MLLAMA = {
"type": "mllama",
"name": "Mllama",
"url": "https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct",
"multimodal": True,
}

__GLOBALS = locals()
for data in ModelType:
__GLOBALS[data.name] = data.value["type"]



def get_model(
model_id: str,
@@ -186,7 +201,7 @@ def get_model(
)

if model_type == "llava_next":
- return VlmCausalLM(
+ return StaticVlmCausalLM(
model_class=LlavaNextForConditionalGeneration,
model_id=model_id,
revision=revision,
@@ -195,6 +210,20 @@
dtype=dtype,
trust_remote_code=trust_remote_code,
)

if model_type == MLLAMA:
return MllamaCausalLM(
model_id=model_id,
model_class=MllamaForConditionalGeneration,
batch_class=MllamaCausalLMBatch,
revision=revision,
quantize=quantize,
speculator=speculator,
dtype=dtype,
default_dtype=torch.bfloat16,
trust_remote_code=trust_remote_code,
lora_adapter_ids=lora_adapter_ids,
)

if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
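The `ModelType` enum plus the `__GLOBALS = locals()` loop above is what makes the later `model_type == MLLAMA` comparison work: each member's `"type"` string is injected into the module namespace under the member's name. A self-contained sketch of the same pattern (names are illustrative, not the exact TGI ones):

```python
import enum

class ModelType(enum.Enum):
    MLLAMA = {
        "type": "mllama",
        "name": "Mllama",
        "multimodal": True,
    }

# Inject each member's "type" string as a module-level constant, e.g. MLLAMA == "mllama".
_globals = globals()
for member in ModelType:
    _globals[member.name] = member.value["type"]

def dispatch(model_type: str) -> str:
    # Mirrors get_model(): the config's model_type string is compared
    # against the injected constant rather than a hard-coded literal.
    if model_type == MLLAMA:  # MLLAMA was injected by the loop above
        return "MllamaCausalLM"
    return "CausalLM"

print(dispatch("mllama"))  # MllamaCausalLM
print(dispatch("llama"))   # CausalLM
```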