[Core] Create metrics for external cache services (#3)
* Enable vineyard llm kv cache in vLLM

Based on another version of vllm: sighingnow@d347dab

Cherry-pick from commit d347dab

Signed-off-by: Tao He <linzhu.ht@alibaba-inc.com>
(cherry picked from commit 1545f6bf7edcd667e305d3fbcadd913066f04747)

resolving vllm update diff

temporarily comment out torch.distributed for single node env

add VineyardCacheConfig with https://github.com/v6d-io/v6d/blob/ebe8f077e3d3780a27d49238c501854b6b8e29df/modules/llm-cache/ds/kv_cache_block.cc#L163 commented out; cache_ops fix

remove CacheConfig from argument (configure through ENV)

v6d: fix integration w/ v1 APIs

Signed-off-by: Haiyang Shi <haiyang.shi@bytedance.com>

Change model_runner to latest version

cherry-pick model_runner from source commit sighingnow@d347dab

fix reshape_and_cache_flash argument

add cache prefetch/update to work_base

clean up

Fix after rebase to 029c71d

remove tensor copy from cache managed address to pin memory

clean up

Add fixes to address comments

adding cache service metrics initial

adding cache service metrics initial

update ttft metrics

update prefix caching with max num seqs argument

* fix token_len stat; Add median metrics

* FIX: avg metrics collection; using cuda_event to collect metrics

* add reshape time

* Address comments

---------

Co-authored-by: Tao He <linzhu.ht@alibaba-inc.com>
happyandslow and sighingnow authored Oct 25, 2024
1 parent 7cdac48 commit a8ae12c
Showing 10 changed files with 149 additions and 28 deletions.
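The hunks below thread a single CacheServiceMetrics object from LLMEngine through the executor and worker down to the Vineyard cache and the stat loggers. The following is a minimal sketch of that call chain, not part of the diff: the class bodies are heavily simplified, and the Sketch class names, field subset, and block counts are illustrative only.

from typing import Optional


class CacheServiceMetrics:
    """Counters and timing samples for an external KV-cache service (simplified)."""
    def __init__(self) -> None:
        self.hit_tokens = 0      # tokens served from the cache service
        self.total_tokens = 0    # tokens requested from the cache service
        self.hit_blocks = 0
        self.total_blocks = 0
        self.counter = 0         # number of measurements collected
        self.time_query: list = []
        self.time_load: list = []


class ModelRunnerSketch:
    def __init__(self) -> None:
        self.vineyard_llm_cache = None  # created when VLLM_USE_VINEYARD_CACHE is set
        self.cache_service_metrics: Optional[CacheServiceMetrics] = None

    def set_cache_service_metrics(self, metrics: CacheServiceMetrics) -> None:
        # The real model runner also hands the object to VineyardLLMCache.
        self.cache_service_metrics = metrics
        if self.vineyard_llm_cache is not None:
            self.vineyard_llm_cache.metrics = metrics


class WorkerSketch:
    def __init__(self) -> None:
        self.model_runner = ModelRunnerSketch()

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int,
                         cache_service_metrics: Optional[CacheServiceMetrics] = None) -> None:
        # The real worker also allocates the KV cache and warms up the model.
        self.model_runner.set_cache_service_metrics(cache_service_metrics)


# Engine side: one metrics object is created and passed to every worker,
# so hit counts and timings accumulate in a single place.
metrics = CacheServiceMetrics()
worker = WorkerSketch()
worker.initialize_cache(num_gpu_blocks=1024, num_cpu_blocks=256,
                        cache_service_metrics=metrics)
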
8 changes: 7 additions & 1 deletion benchmarks/benchmark_prefix_caching.py
@@ -136,7 +136,9 @@ def main(args):
use_v2_block_manager=args.use_v2_block_manager,
tensor_parallel_size=args.tensor_parallel_size,
enable_prefix_caching=args.enable_prefix_caching,
enable_chunked_prefill=args.enable_chunked_prefill)
enable_chunked_prefill=args.enable_chunked_prefill,
max_num_seqs = args.max_num_seqs,
)

sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)

@@ -198,5 +200,9 @@ def main(args):
default='128:256',
help='Range of input lengths for sampling prompts,'
'specified as "min:max" (e.g., "128:256").')
parser.add_argument('--max-num-seqs',
type=int,
default=256,
help='Maximum number of sequences per iteration.')
args = parser.parse_args()
main(args)
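With the new flag, the benchmark forwards max_num_seqs to the LLM constructor alongside the existing options. A hedged usage sketch, separate from the diff; the model name and limits are placeholders:

from vllm import LLM, SamplingParams

# Illustrative only: mirrors how benchmark_prefix_caching.py builds the engine
# after this change.
llm = LLM(model="facebook/opt-125m",
          enable_prefix_caching=True,
          enable_chunked_prefill=True,
          max_num_seqs=256)

sampling_params = SamplingParams(temperature=0, max_tokens=128)
outputs = llm.generate(["Hello, my name is"], sampling_params)
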
5 changes: 3 additions & 2 deletions vllm/engine/llm_engine.py
@@ -54,6 +54,7 @@
usage_message)
from vllm.utils import Counter, Device
from vllm.version import __version__ as VLLM_VERSION
from vllm.worker.vineyard_llm_cache import CacheServiceMetrics

logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
@@ -296,6 +297,7 @@ def __init__(
)
self.log_stats = log_stats
self.step_return_finished_only = step_return_finished_only
self.cache_service_metrics = CacheServiceMetrics

if not self.model_config.skip_tokenizer_init:
self.tokenizer = self._init_tokenizer()
@@ -422,7 +424,6 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
# See https://prometheus.github.io/client_python/multiprocess/
from vllm.engine.metrics import (LoggingStatLogger,
PrometheusStatLogger)

self.stat_loggers = {
"logging":
LoggingStatLogger(
@@ -477,7 +478,7 @@ def _initialize_kv_caches(self) -> None:
self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks, cache_service_metrics=self.cache_service_metrics)

@classmethod
def _get_executor_cls(cls,
7 changes: 7 additions & 0 deletions vllm/engine/metrics.py
@@ -17,6 +17,7 @@

if TYPE_CHECKING:
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.worker.vineyard_llm_cache import CacheServiceMetrics

logger = init_logger(__name__)

@@ -374,6 +375,12 @@ def log(self, stats: Stats) -> None:
self._format_spec_decode_metrics_str(
self.spec_decode_metrics))

if self.external_cache_service_metrics is not None:
logger.info(
"Cache service hit rate: by tokens: %.2f%%, by blocks: %.2f%%",
0 if self.external_cache_service_metrics.total_tokens == 0 else self.external_cache_service_metrics.hit_tokens/self.external_cache_service_metrics.total_tokens * 100,
0 if self.external_cache_service_metrics.total_blocks == 0 else self.external_cache_service_metrics.hit_blocks/self.external_cache_service_metrics.total_blocks * 100,
)
# Reset tracked stats for next interval.
self.num_prompt_tokens = []
self.num_generation_tokens = []
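The zero-guarded percentage expression above appears in several log calls; a small helper capturing the same arithmetic could read as follows (the function name is illustrative, not part of the diff):

def hit_rate_pct(hits: int, total: int) -> float:
    """Return hits/total as a percentage, or 0.0 when nothing was requested."""
    return 0.0 if total == 0 else hits / total * 100


# Example: 1536 of 2048 requested tokens were served by the cache service.
assert abs(hit_rate_pct(1536, 2048) - 75.0) < 1e-9
assert hit_rate_pct(0, 0) == 0.0
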
5 changes: 3 additions & 2 deletions vllm/engine/metrics_types.py
@@ -17,7 +17,7 @@
from typing import Dict, List, Optional, Protocol

from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics

from vllm.worker.vineyard_llm_cache import CacheServiceMetrics

@dataclass
class Stats:
@@ -71,8 +71,9 @@ def __init__(self, local_interval: float) -> None:
self.num_generation_tokens: List[int] = []
self.last_local_log = time.time()
self.local_interval = local_interval
self.external_cache_service_metrics = CacheServiceMetrics
self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None

@abstractmethod
def log(self, stats: Stats) -> None:
raise NotImplementedError
33 changes: 32 additions & 1 deletion vllm/entrypoints/llm.py
@@ -26,6 +26,8 @@
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs, is_list_of

import numpy as np

logger = init_logger(__name__)


@@ -716,10 +718,14 @@ def _run_engine(
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_in_toks = 0
total_out_toks = 0
times_to_first_token = []
normalized_times_to_first_token = []
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
times_to_first_token.append(output.metrics.first_token_time - output.metrics.first_scheduled_time)
normalized_times_to_first_token.append((output.metrics.first_token_time - output.metrics.first_scheduled_time)/len(output.prompt_token_ids))
outputs.append(output)
if use_tqdm:
if isinstance(output, RequestOutput):
@@ -734,10 +740,35 @@ def _run_engine(
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
pbar.update(1)
if self.llm_engine.cache_service_metrics is not None:
logger.info(
"Cache service hit rate: by tokens: %.2f%%, by blocks: %.2f%%, total tokens hit %d, number of measurement collected %d",
0 if self.llm_engine.cache_service_metrics.total_tokens == 0 else self.llm_engine.cache_service_metrics.hit_tokens/self.llm_engine.cache_service_metrics.total_tokens * 100,
0 if self.llm_engine.cache_service_metrics.total_blocks == 0 else self.llm_engine.cache_service_metrics.hit_blocks/self.llm_engine.cache_service_metrics.total_blocks * 100,
self.llm_engine.cache_service_metrics.hit_tokens,
self.llm_engine.cache_service_metrics.counter,
)
logger.info(
"Cache service time query avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time load avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time reshape avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time unload avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time update avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time to first token avg %.4f std %.4f median %.4f, normalized time to first token %.4f std %.4f median %.4f",
np.mean(self.llm_engine.cache_service_metrics.time_query), np.std(self.llm_engine.cache_service_metrics.time_query), np.median(self.llm_engine.cache_service_metrics.time_query),
np.mean(self.llm_engine.cache_service_metrics.normalized_time_query), np.std(self.llm_engine.cache_service_metrics.normalized_time_query), np.median(self.llm_engine.cache_service_metrics.normalized_time_query),
np.mean(self.llm_engine.cache_service_metrics.time_load), np.std(self.llm_engine.cache_service_metrics.time_load), np.median(self.llm_engine.cache_service_metrics.time_load),
np.mean(self.llm_engine.cache_service_metrics.normalized_time_load), np.std(self.llm_engine.cache_service_metrics.normalized_time_load), np.median(self.llm_engine.cache_service_metrics.normalized_time_load),
np.mean(self.llm_engine.cache_service_metrics.time_reshape), np.std(self.llm_engine.cache_service_metrics.time_reshape), np.median(self.llm_engine.cache_service_metrics.time_reshape),
np.mean(self.llm_engine.cache_service_metrics.normalized_time_reshape), np.std(self.llm_engine.cache_service_metrics.normalized_time_reshape), np.median(self.llm_engine.cache_service_metrics.normalized_time_reshape),
np.mean(self.llm_engine.cache_service_metrics.time_unload), np.std(self.llm_engine.cache_service_metrics.time_unload), np.median(self.llm_engine.cache_service_metrics.time_unload),
np.mean(self.llm_engine.cache_service_metrics.normalized_time_unload), np.std(self.llm_engine.cache_service_metrics.normalized_time_unload), np.median(self.llm_engine.cache_service_metrics.normalized_time_unload),
np.mean(self.llm_engine.cache_service_metrics.time_update), np.std(self.llm_engine.cache_service_metrics.time_update), np.median(self.llm_engine.cache_service_metrics.time_update),
np.mean(self.llm_engine.cache_service_metrics.normalized_time_update), np.std(self.llm_engine.cache_service_metrics.normalized_time_update), np.median(self.llm_engine.cache_service_metrics.normalized_time_update),
np.mean(times_to_first_token), np.std(times_to_first_token), np.median(times_to_first_token),
np.mean(normalized_times_to_first_token), np.std(normalized_times_to_first_token), np.median(normalized_times_to_first_token),

)


# Restore original behavior
self.llm_engine.step_return_finished_only = False

if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
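The summary added to _run_engine derives time-to-first-token per finished request from output.metrics.first_token_time minus output.metrics.first_scheduled_time, optionally normalized by prompt length, then reports mean, std, and median via NumPy. A standalone sketch of that aggregation with made-up sample values:

import numpy as np

# Hypothetical per-request TTFT samples in seconds and prompt lengths in tokens.
ttft = [0.121, 0.098, 0.143, 0.105]
prompt_lens = [256, 128, 512, 192]
normalized_ttft = [t / n for t, n in zip(ttft, prompt_lens)]

print("TTFT avg %.4f std %.4f median %.4f" %
      (np.mean(ttft), np.std(ttft), np.median(ttft)))
print("normalized TTFT avg %.6f std %.6f median %.6f" %
      (np.mean(normalized_ttft), np.std(normalized_ttft), np.median(normalized_ttft)))
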
8 changes: 6 additions & 2 deletions vllm/executor/distributed_gpu_executor.py
@@ -8,6 +8,7 @@
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
from vllm.worker.vineyard_llm_cache import CacheServiceMetrics

logger = init_logger(__name__)

@@ -47,7 +48,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
return num_gpu_blocks, num_cpu_blocks

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
num_cpu_blocks: int,
cache_service_metrics: Optional[CacheServiceMetrics] = None) -> None:
"""Initialize the KV cache in all workers.
"""

@@ -62,7 +64,9 @@ def initialize_cache(self, num_gpu_blocks: int,

self._run_workers("initialize_cache",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
num_cpu_blocks=num_cpu_blocks,
cache_service_metrics = cache_service_metrics
)

def execute_model(
self,
14 changes: 10 additions & 4 deletions vllm/executor/gpu_executor.py
@@ -9,6 +9,7 @@
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
from vllm.worker.vineyard_llm_cache import CacheServiceMetrics

logger = init_logger(__name__)

@@ -113,16 +114,21 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
"""
return self.driver_worker.determine_num_available_blocks()

def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
def initialize_cache(self,
num_gpu_blocks: int,
num_cpu_blocks: int,
cache_service_metrics: Optional[CacheServiceMetrics] = None) -> None:
"""Initialize the KV cache by invoking the underlying worker.
"""
# NOTE: This is logged in the executor because there can be >1 worker
# with other executors. We could log in the engine level, but work
# remains to abstract away the device for non-GPU configurations.
logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
num_cpu_blocks)
logger.info("# GPU blocks: %d, # CPU blocks: %d cache_service_metrics %s", num_gpu_blocks, num_cpu_blocks, cache_service_metrics)

self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
self.driver_worker.initialize_cache(
num_gpu_blocks,
num_cpu_blocks,
cache_service_metrics = cache_service_metrics)

def execute_model(
self, execute_model_req: ExecuteModelRequest
12 changes: 10 additions & 2 deletions vllm/worker/model_runner.py
@@ -55,6 +55,7 @@
_add_sampling_metadata_broadcastable_dict,
_init_attn_metadata_from_tensor_dict,
_init_sampling_metadata_from_tensor_dict)
from vllm.worker.vineyard_llm_cache import CacheServiceMetrics

if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@@ -976,6 +977,7 @@ def __init__(
# Multi-modal data support
self.input_registry = input_registry
self.mm_registry = mm_registry
self.cache_service_metrics = None
self.multi_modal_input_mapper = mm_registry \
.create_input_mapper(model_config)
self.mm_registry.init_mm_limits_per_prompt(self.model_config)
@@ -997,7 +999,7 @@ def __init__(
# to ensure the tensor model parallel group is initialized.
self.vineyard_llm_cache = None

def _init_vineyard_cache(self):
def _init_vineyard_cache(self, metrics: CacheServiceMetrics = None):
if envs.VLLM_USE_VINEYARD_CACHE:
if not self.scheduler_config.chunked_prefill_enabled:
raise Exception("Vineyard LLM cache is not enabled, requires chunked prefill")
@@ -1011,6 +1013,7 @@ def _init_vineyard_cache(self):
kv_cache_dtype=self.kv_cache_dtype,
torch_dtype=get_kv_cache_torch_dtype(self.kv_cache_dtype,
self.model_config.dtype),
metrics = metrics,
)
logger.info("Using Vineyard LLM cache")
else:
@@ -1020,6 +1023,7 @@ def _init_vineyard_cache(self):
self.inter_data_cache: Dict[int, PyObjectCache] = {}
self.sampling_metadata_cache: SamplingMetadataCache = \
SamplingMetadataCache()
logger.info(f"Initialize CacheServiceMetric {metrics}")

def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
@@ -1096,7 +1100,7 @@ def load_model(self) -> None:
self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend="eager")
self._init_vineyard_cache()
self._init_vineyard_cache(self.cache_service_metrics)

def save_sharded_state(
self,
@@ -1128,6 +1132,10 @@ def save_tensorized_model(
def get_max_block_per_batch(self) -> int:
block_size = self.block_size
return (self.max_seq_len_to_capture + block_size - 1) // block_size

def set_cache_service_metrics(self, metrics: CacheServiceMetrics) -> None:
if self.vineyard_llm_cache is not None:
self.vineyard_llm_cache.metrics = metrics

def _prepare_model_input_tensors(
self,
78 changes: 66 additions & 12 deletions vllm/worker/vineyard_llm_cache.py
@@ -1,5 +1,6 @@
import logging
import time
import numpy as np
from typing import Dict, List, NamedTuple, Optional, Set, Tuple

import torch
@@ -24,7 +25,25 @@

logger = init_logger(__name__)


class CacheServiceMetrics:
hit_tokens: int = 0 # Total number of tokens hit.
total_tokens: int = 0 # Total number of tokens requested.
hit_blocks: int = 0 # Total number of blocks hit.
total_blocks: int = 0 # Total number of blocks requested.
counter: int = 0 # Total number of measurements.
time_query: list = [] # Times used query cache from cache service.
time_load: list = [] # Times used load fetched cache to device memory.
time_reshape: list = [] # Times used reshaping tensors for flash attention KV format.
time_unload: list = [] # Times used move computed KV from device memory.
time_update: list = [] # Times used update computed KV to cache service.
normalized_time_query: list = [] # Times used query cache from cache service normalized by number of tokens.
normalized_time_load: list = [] # Times used load fetched cache to device memory normalized by number of tokens.
normalized_time_reshape: list = [] # Times used reshaping tensors for flash attention KV format normalized by number of tokens.
normalized_time_unload: list = [] # Times used move computed KV from device memory normalized by number of tokens.
normalized_time_update: list = [] # Times used update computed KV to cache service normalized by number of tokens.



class VineyardLLMCache:
def __init__(
self,
@@ -34,6 +53,7 @@ def __init__(
layer: int = 2,
kv_cache_dtype: str = None,
torch_dtype: torch.dtype = torch.bfloat16,
metrics: CacheServiceMetrics = None
):
self._init_vineyard_logger()

@@ -67,6 +87,8 @@ def __init__(
VineyardKVTensor(k_tensor.data_ptr(), k_tensor.numel() * k_tensor.element_size()),
VineyardKVTensor(v_tensor.data_ptr(), v_tensor.numel() * v_tensor.element_size()),
))
self.metrics = metrics
logger.info(f"VineyardLLMCache init {metrics}")

def _init_vineyard_logger(self):
import vineyard
@@ -83,6 +105,7 @@ def from_envs(
parallel_config: ParallelConfig,
kv_cache_dtype: str,
torch_dtype: torch.dtype = torch.bfloat16,
metrics: CacheServiceMetrics = None,
) -> Optional["VineyardLLMCache"]:
if VineyardKVCache is None:
logger.warn("VineyardKVCache module is not available")
@@ -95,14 +118,15 @@ def from_envs(
head_size = model_config.get_head_size()
num_kv_heads = model_config.get_num_kv_heads(parallel_config)
num_layers = model_config.get_num_layers(parallel_config)

logger.info(f"VineyardLLMCache from_envs {metrics}")
return VineyardLLMCache(
head_size=head_size,
num_kv_heads=num_kv_heads,
cache_capacity=2**20,
layer=num_layers,
kv_cache_dtype=kv_cache_dtype,
torch_dtype=torch_dtype,
metrics = metrics
)

def prefetch_seq_kv_caches(
@@ -129,7 +153,6 @@ def prefetch_seq_kv_caches(
# alignment `context_len` to `self.chunk_size`
query_context_len = context_len - context_len % self.chunk_size
query_token_size = context_len + token_chunk_size - query_context_len

query_prefix = tokens[:query_context_len]
query_tokens = tokens[query_context_len:query_context_len + query_token_size]
query_args = [
@@ -156,13 +179,15 @@ def prefetch_seq_kv_caches(
query_tokens
) = query_args


start_time = time.time()
matched = self.cache.query(
prefix=query_prefix,
tokens=query_tokens,
kv_cache_list=self.tensors[:query_token_size],
)

duration = time.time() - start_time
self.metrics.time_query.append(duration)
self.metrics.normalized_time_query.append(duration/len(tokens))
# synchronized across tensor parallel ranks
matched_tensor = torch.tensor([matched], dtype=torch.long, device='cuda')
# torch.distributed.all_reduce(matched_tensor, op=torch.distributed.ReduceOp.MIN,
@@ -176,7 +201,6 @@ def prefetch_seq_kv_caches(
matched = min(matched, token_chunk_size - 1)
if matched <= 0:
return seq_id, 0

if seq_group_metadata is not None:
block_table = seq_group_metadata.block_tables[seq_id]
slot_mapping = []
@@ -193,9 +217,25 @@ def prefetch_seq_kv_caches(
# torch.distributed.broadcast(slot_mapping, src=0,
# group=get_tensor_model_parallel_group())


self.metrics.hit_tokens += matched
self.metrics.total_tokens += query_token_size
self.metrics.hit_blocks += (matched // block_size)
self.metrics.total_blocks += ((-query_token_size) // (-block_size))
self.metrics.counter += 1
# save to GPU kv cache
copy_start = torch.cuda.Event(enable_timing=True)
copy_end = torch.cuda.Event(enable_timing=True)
copy_start.record()
buffer = self.buffer[:, :, offset:offset+matched].cuda()
copy_end.record()
torch.cuda.synchronize()
duration = copy_start.elapsed_time(copy_end) / 1000.0
self.metrics.time_load.append(duration)
self.metrics.normalized_time_load.append(0 if matched == 0 else duration/matched)

reshape_start = torch.cuda.Event(enable_timing=True)
reshape_end = torch.cuda.Event(enable_timing=True)
reshape_start.record()
for j in range(self.layer):
# use `reshape_and_cache_flash` rather than `copy_` as
# the target kv cache slots are not contiguous.
@@ -209,7 +249,12 @@ def prefetch_seq_kv_caches(
1.0,
1.0
)

reshape_end.record()
torch.cuda.synchronize()
duration = reshape_start.elapsed_time(reshape_end) / 1000.0
self.metrics.time_reshape.append(duration)
self.metrics.normalized_time_reshape.append(0 if matched == 0 else duration/matched)

# update the seq_group_metadata's and seq's metadata
if seq_group_metadata is not None:
seq_data.update_num_computed_tokens(matched)
@@ -243,7 +288,6 @@ def prefetch_kv_caches(
# group=get_tensor_model_parallel_group())
prefill_requests = [None] * num_prefill_requests[0]
num_prefill_requests = num_prefill_requests[0]

matched = {}
for seq_group_meta in prefill_requests:
seq_id, seq_matched = self.prefetch_seq_kv_caches(
@@ -297,14 +341,12 @@ def update_seq_kv_caches(
update_prefix,
update_tokens,
) = update_args

if update_token_size <= 0:
# restore the seq_group_metadata's and seq's metadata
if seq_group_metadata is not None:
seq_data.update_num_computed_tokens(-matched[seq_id])
seq_group_metadata.token_chunk_size += matched[seq_id]
return seq_id, 0

if seq_group_metadata is not None:
block_table = seq_group_metadata.block_tables[seq_id]
slot_mapping = []
@@ -322,16 +364,28 @@ def update_seq_kv_caches(
# group=get_tensor_model_parallel_group())

# fetch from GPU kv cache
start_unload = torch.cuda.Event(enable_timing=True)
end_unload = torch.cuda.Event(enable_timing=True)
start_unload.record()
for j in range(self.layer):
self.buffer[:, j, :update_token_size].copy_(
kv_caches[j][:, slot_mapping // block_size, slot_mapping % block_size])

end_unload.record()
torch.cuda.synchronize()
duration = start_unload.elapsed_time(end_unload) / 1000.0
self.metrics.time_unload.append(duration)
self.metrics.normalized_time_unload.append(0 if update_token_size == 0 else duration/update_token_size)

start_time = time.time()
# updates into vineyard
updated = self.cache.update(
prefix=update_prefix,
tokens=update_tokens,
kv_cache_list=self.tensors[:update_token_size],
)
duration = time.time() - start_time
self.metrics.time_update.append(duration)
self.metrics.normalized_time_update.append(0 if update_token_size == 0 else duration/update_token_size)

# restore the seq_group_metadata's and seq's metadata
if seq_group_metadata is not None:
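The load, reshape, and unload timings above are measured with CUDA events rather than wall-clock time, so asynchronous copy kernels are accounted for correctly. A minimal sketch of that pattern, not tied to the Vineyard code and guarded to run on CPU-only machines:

import torch


def timed_copy_to_gpu(buffer: torch.Tensor) -> float:
    """Copy `buffer` to the GPU and return the elapsed time in seconds."""
    if not torch.cuda.is_available():
        return 0.0  # CPU-only fallback for this sketch
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    _ = buffer.cuda(non_blocking=True)
    end.record()
    torch.cuda.synchronize()  # both events must complete before elapsed_time()
    return start.elapsed_time(end) / 1000.0  # elapsed_time() reports milliseconds


elapsed = timed_copy_to_gpu(torch.empty(2, 16, 4096, dtype=torch.bfloat16))
print(f"host-to-device copy took {elapsed:.4f}s")
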
7 changes: 5 additions & 2 deletions vllm/worker/worker.py
@@ -28,6 +28,7 @@
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
from vllm.worker.vineyard_llm_cache import CacheServiceMetrics

logger = init_logger(__name__)

@@ -250,7 +251,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
return num_gpu_blocks, num_cpu_blocks

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
num_cpu_blocks: int,
cache_service_metrics: Optional[CacheServiceMetrics] = None) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks.
This also warms up the model, which may record CUDA graphs.
@@ -261,7 +263,8 @@ def initialize_cache(self, num_gpu_blocks: int,

self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

logger.info(f"worker initialize_cache initialize cache_service_metrics {cache_service_metrics} ")
self.model_runner.set_cache_service_metrics(cache_service_metrics)
self._init_cache_engine()
self._warm_up_model()
