[V1] Move more control of kv cache initialization from model_executor to EngineCore #11960
@@ -0,0 +1,93 @@ (new file)
from typing import List

import torch

from vllm.v1.utils import bind_kv_cache


def test_bind_kv_cache():
    from vllm.attention import Attention

    ctx = {
        'layers.0.self_attn': Attention(32, 128, 0.1),
        'layers.1.self_attn': Attention(32, 128, 0.1),
        'layers.2.self_attn': Attention(32, 128, 0.1),
        'layers.3.self_attn': Attention(32, 128, 0.1),
    }
    kv_cache = {
        'layers.0.self_attn': torch.zeros((1, )),
        'layers.1.self_attn': torch.zeros((1, )),
        'layers.2.self_attn': torch.zeros((1, )),
        'layers.3.self_attn': torch.zeros((1, )),
    }
    runner_kv_caches: List[torch.Tensor] = []
    bind_kv_cache(ctx, runner_kv_caches, kv_cache)
    assert ctx['layers.0.self_attn'].kv_cache[0] is kv_cache[
        'layers.0.self_attn']
    assert ctx['layers.1.self_attn'].kv_cache[0] is kv_cache[
        'layers.1.self_attn']
    assert ctx['layers.2.self_attn'].kv_cache[0] is kv_cache[
        'layers.2.self_attn']
    assert ctx['layers.3.self_attn'].kv_cache[0] is kv_cache[
        'layers.3.self_attn']

    assert runner_kv_caches[0] is kv_cache['layers.0.self_attn']
    assert runner_kv_caches[1] is kv_cache['layers.1.self_attn']
    assert runner_kv_caches[2] is kv_cache['layers.2.self_attn']
    assert runner_kv_caches[3] is kv_cache['layers.3.self_attn']


def test_bind_kv_cache_non_attention():
    from vllm.attention import Attention

    # example from Jamba PP=2
    ctx = {
        'model.layers.20.attn': Attention(32, 128, 0.1),
        'model.layers.28.attn': Attention(32, 128, 0.1),
    }
    kv_cache = {
        'model.layers.20.attn': torch.zeros((1, )),
        'model.layers.28.attn': torch.zeros((1, )),
    }

    runner_kv_caches: List[torch.Tensor] = []
    bind_kv_cache(ctx, runner_kv_caches, kv_cache)

    assert ctx['model.layers.20.attn'].kv_cache[0] is kv_cache[
        'model.layers.20.attn']
    assert ctx['model.layers.28.attn'].kv_cache[0] is kv_cache[
        'model.layers.28.attn']

    assert runner_kv_caches[0] is kv_cache['model.layers.20.attn']
    assert runner_kv_caches[1] is kv_cache['model.layers.28.attn']


def test_bind_kv_cache_encoder_decoder():
    from vllm.attention import Attention, AttentionType

    # example from bart
    ctx = {
        'encoder.layers.0.self_attn.attn':
        Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER),
        'decoder.layers.0.encoder_attn.attn':
        Attention(32, 128, 0.1, attn_type=AttentionType.ENCODER_DECODER),
        'decoder.layers.0.self_attn.attn':
        Attention(32, 128, 0.1, attn_type=AttentionType.DECODER),
    }

    kv_cache_tensor = torch.zeros((1, ))
    kv_cache = {
        'decoder.layers.0.encoder_attn.attn': kv_cache_tensor,
        'decoder.layers.0.self_attn.attn': kv_cache_tensor,
    }
    encoder_kv_cache = ctx['encoder.layers.0.self_attn.attn'].kv_cache

    runner_kv_caches: List[torch.Tensor] = []
    bind_kv_cache(ctx, runner_kv_caches, kv_cache)
    assert ctx['encoder.layers.0.self_attn.attn'].kv_cache is encoder_kv_cache
    assert ctx['decoder.layers.0.encoder_attn.attn'].kv_cache[0] is kv_cache[
        'decoder.layers.0.encoder_attn.attn']
    assert ctx['decoder.layers.0.self_attn.attn'].kv_cache[0] is kv_cache[
        'decoder.layers.0.self_attn.attn']

    assert runner_kv_caches[0] is kv_cache_tensor
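For readers unfamiliar with `bind_kv_cache`, here is a minimal sketch of the behavior these tests pin down. It is an illustration only: the real implementation lives in vllm/v1/utils.py and may differ; in particular, how layer indices are extracted from layer names and how the Attention module's kv_cache slot is shaped are assumptions here.

    # Sketch only: the real bind_kv_cache is in vllm/v1/utils.py.
    import re
    from typing import Dict, List

    import torch


    def bind_kv_cache_sketch(ctx, runner_kv_caches: List[torch.Tensor],
                             kv_caches: Dict[str, torch.Tensor]) -> None:
        # Layers that did not get a KV cache tensor (e.g. encoder
        # self-attention in an encoder-decoder model) are left untouched.
        def layer_index(name: str) -> int:
            # Assumption: the layer index is the first ".<int>." in the name.
            return int(re.search(r"\.(\d+)\.", name).group(1))

        # Fill the runner's flat per-layer list in layer order, de-duplicating
        # tensors shared by several attention modules of the same layer
        # (e.g. decoder self-attn and cross-attn in BART).
        index_to_tensor: Dict[int, torch.Tensor] = {}
        for name, tensor in kv_caches.items():
            index_to_tensor[layer_index(name)] = tensor
        runner_kv_caches.extend(index_to_tensor[i]
                                for i in sorted(index_to_tensor))

        # Point each attention module's kv_cache slot at the shared tensor so
        # the module and the runner see the same memory. Assumption: the
        # module's kv_cache attribute is a list with one slot per engine.
        for name, tensor in kv_caches.items():
            ctx[name].kv_cache[0] = tensor

With this shape of API, the model runner keeps a plain List[torch.Tensor] indexed by layer, while each Attention module still reaches its cache through its own kv_cache attribute.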
@@ -3,7 +3,10 @@
 from dataclasses import dataclass
 from typing import Any, List, NamedTuple, Optional, Tuple

+from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.v1.kv_cache_interface import (KVCacheConfig, KVCacheSpec,
+                                        KVCacheTensor)
 from vllm.v1.request import Request

 logger = init_logger(__name__)
@@ -305,3 +308,118 @@ def hash_request_tokens(block_size: int,
        ret.append(block_hash)
        parent_block_hash_value = block_hash.hash_value
    return ret


def check_enough_kv_cache_memory(vllm_config: VllmConfig,
                                 kv_cache_spec: KVCacheSpec,
                                 available_memory: int):
    """
    Checks whether `available_memory` is enough for the KV cache of at least
    one request with the model's max_model_len.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of the model
        available_memory (int): Memory available for KV cache in bytes.

    Raises:
        ValueError: If there is not enough memory available for the KV cache.
    """

    if available_memory <= 0:
        raise ValueError("No available memory for the cache blocks. "
                         "Try increasing `gpu_memory_utilization` when "
                         "initializing the engine.")

    max_model_len = vllm_config.model_config.max_model_len
    needed_memory = 0
    for layer_spec in kv_cache_spec.values():
        needed_memory += layer_spec.bytes_for_tokens(max_model_len)

    if needed_memory > available_memory:
        raise ValueError(
            f"To serve at least one request with the model's max seq len "
            f"({max_model_len}), {needed_memory/1024/1024/1024} GB KV cache "
            f"is needed, which is larger than the available KV cache memory "
            f"({available_memory/1024/1024/1024} GB). Try increasing "
            f"`gpu_memory_utilization` or decreasing `max_model_len` when "
            f"initializing the engine.")


def is_same_type(kv_cache_spec: KVCacheSpec) -> bool:
    """
    Whether all layers in the given KVCacheSpec have the same type of KV
    cache.

    Args:
        kv_cache_spec (KVCacheSpec): The KVCacheSpec of the model

    Returns:
        True if all layers have the same type, False otherwise.
    """

    layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
    return len(layer_keys) == 1

Review comment on `is_same_type`: I feel this function name is a bit unclear. It actually checks whether the KV cache specs of all layers are the same, so it should be informative about "spec", "all", and "same". Maybe `is_uniformed_kv_cache_type` or something like that would be better.
Reply: changed to


def _get_kv_cache_config_same_type(
        vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
        available_memory: int) -> Tuple[KVCacheConfig, int]:
    """
    Generates the KV cache configuration for a model with one type of KV
    cache. Divides the available memory equally among all layers.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of the model
        available_memory (int): Memory available for KV cache in bytes.

    Returns:
        Tuple[KVCacheConfig, int]: The generated KVCacheConfig and the number
        of GPU blocks.
    """

    page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
    assert len(page_sizes) == 1
    page_size = page_sizes.pop()

    num_gpu_blocks = int(available_memory // page_size // len(kv_cache_spec))
    num_gpu_blocks = max(num_gpu_blocks, 0)

    if vllm_config.cache_config.num_gpu_blocks_override is not None:
        num_gpu_blocks_override = \
            vllm_config.cache_config.num_gpu_blocks_override
        logger.info(
            "Overriding num_gpu_blocks=%d with "
            "num_gpu_blocks_override=%d", num_gpu_blocks,
            num_gpu_blocks_override)
        num_gpu_blocks = num_gpu_blocks_override

    logger.info("# GPU blocks: %d", num_gpu_blocks)

    per_layer_size = page_size * num_gpu_blocks

    kv_cache_config = KVCacheConfig(
        tensors={
            layer_name: KVCacheTensor(size=per_layer_size)
            for layer_name in kv_cache_spec
        },
        groups=[[layer_name for layer_name in kv_cache_spec]],
        kv_cache_spec=kv_cache_spec)
    return kv_cache_config, num_gpu_blocks


def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
                        available_memory: int) -> Tuple[KVCacheConfig, int]:
    """
    Generates the KV cache configuration for a model.
    TODO: support hybrid models with more than one type of KV cache.

    Args:
        vllm_config: The global VllmConfig
        kv_cache_spec: The kv cache spec of the model
        available_memory (int): Memory available for KV cache in bytes.

    Returns:
        Tuple[KVCacheConfig, int]: The generated KVCacheConfig and the number
        of GPU blocks.
    """
    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
    if is_same_type(kv_cache_spec):
        # The KV cache specs of all layers are the same, which is true for
        # most models. Allocate the same amount of memory for each layer.
        return _get_kv_cache_config_same_type(vllm_config, kv_cache_spec,
                                              available_memory)
    else:
        raise NotImplementedError

Review comment on `get_kv_cache_config`: Again this function name (also the callee). One possibility:

    def get_kv_cache_config_and_available_blocks(...):
        check_enough_kv_cache_memory(...)
        # Later maybe you can introduce a registry when you have more policies.
        if is_uniformed_kv_cache_type(...):
            return _get_kv_cache_config_and_blocks_for_uniformed_type(...)
        return _get_kv_cache_config_and_blocks_for_xxx(...)

Reply: Changed
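To make the block math in `_get_kv_cache_config_same_type` concrete, here is a small worked example. The page size, layer count, and memory budget below are made-up numbers for illustration, not values from this PR.

    # Worked example (made-up numbers) of the per-layer block math above.
    GiB = 1024**3

    available_memory = 8 * GiB        # budget reported by the workers
    page_size = 2 * 16 * 8 * 128 * 2  # 2 (K and V) * block_size 16 * 8 KV heads
                                      # * head_size 128 * 2 bytes (fp16) = 64 KiB
    num_layers = 32

    # Memory is split evenly across layers; every layer gets the same count.
    num_gpu_blocks = available_memory // page_size // num_layers
    assert num_gpu_blocks == 4096

    # Each layer's KVCacheTensor is then page_size * num_gpu_blocks bytes,
    # and all layers land in a single group because they share one spec type.
    per_layer_size = page_size * num_gpu_blocks
    assert per_layer_size == 256 * 1024**2   # 256 MiB per layer, 8 GiB total

If `num_gpu_blocks_override` is set in the cache config, the computed block count is simply replaced by the override, as the code above shows.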
@@ -23,6 +23,7 @@
 from vllm.utils import (get_distributed_init_method, get_mp_context,
                         get_open_port, get_open_zmq_ipc_path, zmq_socket_ctx)
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -90,29 +91,33 @@ def sigusr1_handler(signum, frame):
         for w in self.workers:
             w.worker_response_mq.wait_until_ready()

-    def initialize(self, num_gpu_blocks: int) -> None:
+    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize the KV caches and begin the model execution loop of the
         underlying workers.
         """
-        logger.info("# GPU blocks: %d", num_gpu_blocks)
-        self.collective_rpc("initialize_cache", args=(num_gpu_blocks, ))
+        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
         self.collective_rpc("compile_or_warm_up_model")

-    def determine_num_available_blocks(self) -> Tuple[int, int]:
+    def determine_available_memory(self) -> int:
         """
-        Determine the number of available KV blocks by invoking the
+        Determine the available memory for KV cache by invoking the
         underlying worker.
         """
-        num_blocks = self.collective_rpc("determine_num_available_blocks")
+        memory_sizes = self.collective_rpc("determine_available_memory")

         # Since we use a shared centralized controller, we take the minimum
-        # number of blocks across all workers to make sure all the memory
+        # memory size across all workers to make sure all the memory
         # operators can be applied to all workers.
-        num_gpu_blocks = min(b[0] for b in num_blocks)
-        num_cpu_blocks = min(b[1] for b in num_blocks)
-        return num_gpu_blocks, num_cpu_blocks
+        return min(memory_sizes)
+
+    def get_kv_cache_spec(self) -> KVCacheSpec:
+        """
+        Get all kv cache needed by the model by invoking the underlying worker.
+        """
+        kv_cache_specs = self.collective_rpc("get_kv_cache_spec")
+        assert all(lc == kv_cache_specs[0] for lc in kv_cache_specs)
+        return kv_cache_specs[0]

     def collective_rpc(self,
                        method: str,

Review comment on `get_kv_cache_spec`: How would this be extended later if you have different specs?
Reply: It won't be extended. kv_cache_spec[i] is for all layers of one GPU. Different TP GPUs always have the same spec, though the spec of each layer inside one GPU can be different.
Review comment: this is heavy code duplication.
Reply: Running the same test as v0, but the detailed workflow is very different due to the dict->list and list->dict difference.
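Taken together, the executor changes above imply an initialization flow that now lives on the engine side rather than in the model executor. The sketch below shows how EngineCore might chain the new methods; the helper name `initialize_kv_caches` and the import path are assumptions for illustration, and the actual call site in this PR may differ.

    # Sketch of the engine-side flow implied by the new executor interface.
    # The function name and exact wiring are assumptions, not code from this PR.
    from vllm.v1.core.kv_cache_utils import get_kv_cache_config  # assumed path
                                                                 # of the helpers
                                                                 # added above


    def initialize_kv_caches(vllm_config, executor):
        # 1. Ask the workers which KV caches the model needs (per-layer spec).
        kv_cache_spec = executor.get_kv_cache_spec()

        # 2. Profile the workers; determine_available_memory() already takes
        #    the minimum free KV cache memory across all ranks.
        available_memory = executor.determine_available_memory()

        # 3. Turn the spec plus the memory budget into a concrete KVCacheConfig
        #    and a GPU block count for the scheduler.
        kv_cache_config, num_gpu_blocks = get_kv_cache_config(
            vllm_config, kv_cache_spec, available_memory)

        # 4. Allocate the caches on the workers and warm up / compile the model.
        executor.initialize(kv_cache_config)
        return num_gpu_blocks

This split is also why the "# GPU blocks" logging moves out of the executor's initialize() in the diff above: the block count is now produced while computing the KVCacheConfig, and the executor only receives the finished config.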