Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1] Move more control of kv cache initialization from model_executor to EngineCore #11960

Merged
merged 22 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions tests/v1/test_utils.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is heavy code duplication.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Running the same test as v0, but the detail workflow is very different due to the dict->list and list->dict difference.

Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import List

import torch

from vllm.v1.utils import bind_kv_cache


def test_bind_kv_cache():
from vllm.attention import Attention

ctx = {
'layers.0.self_attn': Attention(32, 128, 0.1),
'layers.1.self_attn': Attention(32, 128, 0.1),
'layers.2.self_attn': Attention(32, 128, 0.1),
'layers.3.self_attn': Attention(32, 128, 0.1),
}
kv_cache = {
'layers.0.self_attn': torch.zeros((1, )),
'layers.1.self_attn': torch.zeros((1, )),
'layers.2.self_attn': torch.zeros((1, )),
'layers.3.self_attn': torch.zeros((1, )),
}
runner_kv_caches: List[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches)
assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
'layers.0.self_attn']
assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[
'layers.1.self_attn']
assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[
'layers.2.self_attn']
assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[
'layers.3.self_attn']

assert runner_kv_caches[0] is kv_cache['layers.0.self_attn']
assert runner_kv_caches[1] is kv_cache['layers.1.self_attn']
assert runner_kv_caches[2] is kv_cache['layers.2.self_attn']
assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']


def test_bind_kv_cache_non_attention():
from vllm.attention import Attention

# example from Jamba PP=2
ctx = {
'model.layers.20.attn': Attention(32, 128, 0.1),
'model.layers.28.attn': Attention(32, 128, 0.1),
}
kv_cache = {
'model.layers.20.attn': torch.zeros((1, )),
'model.layers.28.attn': torch.zeros((1, )),
}

runner_kv_caches: List[torch.Tensor] = []
bind_kv_cache(kv_cache, ctx, runner_kv_caches)

assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[
'model.layers.20.attn']
assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[
'model.layers.28.attn']

assert runner_kv_caches[0] is kv_cache['model.layers.20.attn']
assert runner_kv_caches[1] is kv_cache['model.layers.28.attn']
2 changes: 2 additions & 0 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def __init__(
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype

# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
Expand Down
124 changes: 124 additions & 0 deletions vllm/v1/core/kv_cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
from dataclasses import dataclass
from typing import Any, List, NamedTuple, Optional, Tuple

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
KVCacheTensor)
from vllm.v1.request import Request

logger = init_logger(__name__)
Expand Down Expand Up @@ -305,3 +308,124 @@ def hash_request_tokens(block_size: int,
ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value
return ret


def check_enough_kv_cache_memory(vllm_config: VllmConfig,
kv_cache_spec: KVCacheSpec,
available_memory: int):
"""
Checks whether `available_memory` is enough for the KV cache to hold at
least one request with the model's max_model_len.

Args:
comaniac marked this conversation as resolved.
Show resolved Hide resolved
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
available_memory: Memory available for KV cache in bytes.

Raises:
ValueError: If there is not enough memory available for the KV cache.
"""

if available_memory <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")

max_model_len = vllm_config.model_config.max_model_len
needed_memory = 0
for layer_spec in kv_cache_spec.values():
needed_memory += layer_spec.bytes_for_tokens(max_model_len)

if needed_memory > available_memory:
raise ValueError(
f"To serve at least one request with the models's max seq len "
f"({max_model_len}), ({needed_memory/1024/1024/1024:.2f} GB KV "
f"cache is needed, which is larger than the available KV cache "
f"memory ({available_memory/1024/1024/1024:.2f} GB). Try "
f"increasing `gpu_memory_utilization` or decreasing "
f"`max_model_len` when initializing the engine.")


def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
"""
Whether all layers in the given KVCacheSpec have the same type of KV cache.

Args:
kv_cache_spec: The KVCacheSpec of the model

Returns:
True if all layers have the same type, False otherwise.
"""

layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
return len(layer_keys) == 1


def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
kv_cache_spec: KVCacheSpec,
available_memory: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model with one type of KV cache.
Divide the available memory equally among all layers.

Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
available_memory: Memory available for KV cache in bytes.

Returns:
The generated KVCacheConfig
"""

page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
assert len(page_sizes) == 1
page_size = page_sizes.pop()

num_blocks = int(available_memory // page_size // len(kv_cache_spec))
num_blocks = max(num_blocks, 0)

if vllm_config.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = \
vllm_config.cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
num_blocks = num_gpu_blocks_override

logger.info("# GPU blocks: %d", num_blocks)

per_layer_size = page_size * num_blocks

kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
tensors={
layer_name: KVCacheTensor(size=per_layer_size)
for layer_name in kv_cache_spec
},
groups=[[layer_name for layer_name in kv_cache_spec]],
kv_cache_spec=kv_cache_spec)
return kv_cache_config


def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
available_memory: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model
TODO: support hybrid models with more than one type of KV cache.

Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of the model
available_memory: Memory available for KV cache in bytes.

Returns:
The generated KVCacheConfig
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for most models.
# Allocate the same amount of memory for each layer.
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory)
else:
raise NotImplementedError
31 changes: 18 additions & 13 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import zmq.asyncio
from msgspec import msgpack

from vllm.config import CacheConfig, VllmConfig
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType,
Expand Down Expand Up @@ -49,7 +50,7 @@ def __init__(

# Setup KV Caches and update CacheConfig after profiling.
num_gpu_blocks, num_cpu_blocks = self._initialize_kv_caches(
vllm_config.cache_config)
vllm_config)
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks

Expand All @@ -65,21 +66,25 @@ def __init__(
vllm_config.model_config)

def _initialize_kv_caches(self,
cache_config: CacheConfig) -> Tuple[int, int]:
vllm_config: VllmConfig) -> Tuple[int, int]:
start = time.time()
num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks(
)

if cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_gpu_blocks,
num_gpu_blocks_override)
num_gpu_blocks = num_gpu_blocks_override
# Get all kv cache needed by the model
kv_cache_spec = self.model_executor.get_kv_cache_spec()

# Profiles the peak memory usage of the model to determine how much
# memory can be allocated for kv cache.
availble_gpu_memory = self.model_executor.determine_available_memory()

# Get the kv cache tensor size
kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
availble_gpu_memory)
num_gpu_blocks = kv_cache_config.num_blocks
num_cpu_blocks = 0
self.model_executor.initialize(num_gpu_blocks)

# Initialize kv cache and warmup the execution
self.model_executor.initialize(kv_cache_config)

elapsed = time.time() - start
logger.info(("init engine (profile, create kv cache, "
"warmup model) took %.2f seconds"), elapsed)
Expand Down
11 changes: 8 additions & 3 deletions vllm/v1/executor/abstract.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC, abstractmethod
from typing import Tuple, Type
from typing import Type

from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput


Expand Down Expand Up @@ -31,11 +32,15 @@ def __init__(self, vllm_config: VllmConfig) -> None:
raise NotImplementedError

@abstractmethod
def initialize(self, num_gpu_blocks: int) -> None:
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
raise NotImplementedError

@abstractmethod
def determine_num_available_blocks(self) -> Tuple[int, int]:
def determine_available_memory(self) -> int: # in bytes
raise NotImplementedError

@abstractmethod
def get_kv_cache_spec(self) -> KVCacheSpec:
raise NotImplementedError

@abstractmethod
Expand Down
25 changes: 15 additions & 10 deletions vllm/v1/executor/multiproc_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from vllm.utils import (get_distributed_init_method, get_mp_context,
get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase

Expand Down Expand Up @@ -90,29 +91,33 @@ def sigusr1_handler(signum, frame):
for w in self.workers:
w.worker_response_mq.wait_until_ready()

def initialize(self, num_gpu_blocks: int) -> None:
def initialize(self, kv_cache_config: KVCacheConfig) -> None:
"""
Initialize the KV caches and begin the model execution loop of the
underlying workers.
"""
logger.info("# GPU blocks: %d", num_gpu_blocks)
self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
self.collective_rpc("compile_or_warm_up_model")

def determine_num_available_blocks(self) -> Tuple[int, int]:
def determine_available_memory(self) -> int:
"""
Determine the number of available KV blocks by invoking the
Determine the available memory (in bytes) for KV cache by invoking the
underlying worker.
"""
num_blocks = self.collective_rpc("determine_num_available_blocks")
memory_sizes = self.collective_rpc("determine_available_memory")

# Since we use a shared centralized controller, we take the minimum
# number of blocks across all workers to make sure all the memory
# memory size across all workers to make sure all the memory
# operators can be applied to all workers.
num_gpu_blocks = min(b[0] for b in num_blocks)
num_cpu_blocks = min(b[1] for b in num_blocks)
return min(memory_sizes)

return num_gpu_blocks, num_cpu_blocks
def get_kv_cache_spec(self) -> KVCacheSpec:
"""
Get all kv cache needed by the model by invoking the underlying worker.
"""
kv_cache_specs = self.collective_rpc("get_kv_cache_spec")
assert all(s == kv_cache_specs[0] for s in kv_cache_specs)
return kv_cache_specs[0]

def collective_rpc(self,
method: str,
Expand Down
Loading
Loading