diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py
index e71d4c1bcde5e..9286b9fe633fa 100644
--- a/benchmarks/benchmark_prefix_caching.py
+++ b/benchmarks/benchmark_prefix_caching.py
@@ -136,7 +136,9 @@ def main(args):
               use_v2_block_manager=args.use_v2_block_manager,
               tensor_parallel_size=args.tensor_parallel_size,
               enable_prefix_caching=args.enable_prefix_caching,
-              enable_chunked_prefill=args.enable_chunked_prefill)
+              enable_chunked_prefill=args.enable_chunked_prefill,
+              max_num_seqs=args.max_num_seqs,
+              )
 
     sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len)
@@ -198,5 +200,9 @@ def main(args):
                         default='128:256',
                         help='Range of input lengths for sampling prompts,'
                         'specified as "min:max" (e.g., "128:256").')
+    parser.add_argument('--max-num-seqs',
+                        type=int,
+                        default=256,
+                        help='Maximum number of sequences per iteration.')
     args = parser.parse_args()
     main(args)
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 94271c4a93151..1339ed97cf138 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -54,6 +54,7 @@
                           usage_message)
 from vllm.utils import Counter, Device
 from vllm.version import __version__ as VLLM_VERSION
+from vllm.worker.vineyard_llm_cache import CacheServiceMetrics
 
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5
@@ -296,6 +297,7 @@ def __init__(
         )
         self.log_stats = log_stats
         self.step_return_finished_only = step_return_finished_only
+        self.cache_service_metrics = CacheServiceMetrics
 
         if not self.model_config.skip_tokenizer_init:
             self.tokenizer = self._init_tokenizer()
@@ -422,7 +424,6 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
             # See https://prometheus.github.io/client_python/multiprocess/
             from vllm.engine.metrics import (LoggingStatLogger,
                                              PrometheusStatLogger)
-
             self.stat_loggers = {
                 "logging":
                 LoggingStatLogger(
@@ -477,7 +478,7 @@ def _initialize_kv_caches(self) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
 
-        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks)
+        self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks, cache_service_metrics=self.cache_service_metrics)
 
     @classmethod
     def _get_executor_cls(cls,
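A note on `self.cache_service_metrics = CacheServiceMetrics` above: the engine stores the class object itself, not an instance. Because `CacheServiceMetrics` (defined in `vllm/worker/vineyard_llm_cache.py` later in this diff) keeps all of its counters and timing lists as class attributes, every component that holds a reference to the class (the engine, the stat loggers, and the Vineyard cache) reads and writes the same state. A minimal, self-contained sketch of that sharing behaviour; `Counters` is a hypothetical stand-in, not part of the PR:

```python
class Counters:
    hits: int = 0             # class attribute: shared by every reference to the class

engine_view = Counters        # like `self.cache_service_metrics = CacheServiceMetrics`
logger_view = Counters        # like `self.external_cache_service_metrics = CacheServiceMetrics`

engine_view.hits += 1         # mutates the class attribute
assert logger_view.hits == 1  # both names observe the same process-wide state
```

An instance-based design (for example a dataclass with `field(default_factory=list)`) would avoid the implicit global state, but would then require threading a single object through every layer touched by this diff.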
diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py
index 74277cae7c8ef..da28355b3694f 100644
--- a/vllm/engine/metrics.py
+++ b/vllm/engine/metrics.py
@@ -17,6 +17,7 @@
 
 if TYPE_CHECKING:
     from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
+    from vllm.worker.vineyard_llm_cache import CacheServiceMetrics
 
 logger = init_logger(__name__)
 
@@ -374,6 +375,12 @@ def log(self, stats: Stats) -> None:
                     self._format_spec_decode_metrics_str(
                         self.spec_decode_metrics))
 
+            if self.external_cache_service_metrics is not None:
+                logger.info(
+                    "Cache service hit rate: by tokens: %.2f%%, by blocks: %.2f%%",
+                    0 if self.external_cache_service_metrics.total_tokens == 0 else self.external_cache_service_metrics.hit_tokens / self.external_cache_service_metrics.total_tokens * 100,
+                    0 if self.external_cache_service_metrics.total_blocks == 0 else self.external_cache_service_metrics.hit_blocks / self.external_cache_service_metrics.total_blocks * 100,
+                )
             # Reset tracked stats for next interval.
             self.num_prompt_tokens = []
             self.num_generation_tokens = []
diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py
index 1eccb23593408..57647f1fa7e5f 100644
--- a/vllm/engine/metrics_types.py
+++ b/vllm/engine/metrics_types.py
@@ -17,7 +17,7 @@
 from typing import Dict, List, Optional, Protocol
 
 from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
-
+from vllm.worker.vineyard_llm_cache import CacheServiceMetrics
 
 @dataclass
 class Stats:
@@ -71,8 +71,9 @@ def __init__(self, local_interval: float) -> None:
         self.num_generation_tokens: List[int] = []
         self.last_local_log = time.time()
         self.local_interval = local_interval
+        self.external_cache_service_metrics = CacheServiceMetrics
         self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None
 
     @abstractmethod
     def log(self, stats: Stats) -> None:
         raise NotImplementedError
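The hit-rate arguments in the `logger.info` call above guard against division by zero before converting to a percentage; the same expression reappears in `vllm/entrypoints/llm.py` below. Factored out, the pattern is simply the following (the helper name is illustrative, nothing the PR actually defines):

```python
def hit_rate_pct(hit: int, total: int) -> float:
    """Hit rate as a percentage; 0.0 while nothing has been requested yet."""
    return 0.0 if total == 0 else hit / total * 100.0

# e.g. hit_rate_pct(metrics.hit_tokens, metrics.total_tokens)
#      hit_rate_pct(metrics.hit_blocks, metrics.total_blocks)
```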
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index b1d9f386b6c3e..c9d3f1b4da01d 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -26,6 +26,8 @@
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Counter, deprecate_kwargs, is_list_of
 
+import numpy as np
+
 logger = init_logger(__name__)
 
@@ -716,10 +718,14 @@ def _run_engine(
         outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
         total_in_toks = 0
         total_out_toks = 0
+        times_to_first_token = []
+        normalized_times_to_first_token = []
         while self.llm_engine.has_unfinished_requests():
             step_outputs = self.llm_engine.step()
             for output in step_outputs:
                 if output.finished:
+                    times_to_first_token.append(output.metrics.first_token_time - output.metrics.first_scheduled_time)
+                    normalized_times_to_first_token.append((output.metrics.first_token_time - output.metrics.first_scheduled_time) / len(output.prompt_token_ids))
                     outputs.append(output)
                     if use_tqdm:
                         if isinstance(output, RequestOutput):
@@ -734,10 +740,35 @@ def _run_engine(
                             f"est. speed input: {in_spd:.2f} toks/s, "
                             f"output: {out_spd:.2f} toks/s")
                         pbar.update(1)
+        if self.llm_engine.cache_service_metrics is not None:
+            logger.info(
+                "Cache service hit rate: by tokens: %.2f%%, by blocks: %.2f%%, total tokens hit %d, number of measurements collected %d",
+                0 if self.llm_engine.cache_service_metrics.total_tokens == 0 else self.llm_engine.cache_service_metrics.hit_tokens / self.llm_engine.cache_service_metrics.total_tokens * 100,
+                0 if self.llm_engine.cache_service_metrics.total_blocks == 0 else self.llm_engine.cache_service_metrics.hit_blocks / self.llm_engine.cache_service_metrics.total_blocks * 100,
+                self.llm_engine.cache_service_metrics.hit_tokens,
+                self.llm_engine.cache_service_metrics.counter,
+            )
+            logger.info(
+                "Cache service time query avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time load avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time reshape avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time unload avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time update avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time to first token avg %.4f std %.4f median %.4f, normalized time to first token avg %.4f std %.4f median %.4f",
+                np.mean(self.llm_engine.cache_service_metrics.time_query), np.std(self.llm_engine.cache_service_metrics.time_query), np.median(self.llm_engine.cache_service_metrics.time_query),
+                np.mean(self.llm_engine.cache_service_metrics.normalized_time_query), np.std(self.llm_engine.cache_service_metrics.normalized_time_query), np.median(self.llm_engine.cache_service_metrics.normalized_time_query),
+                np.mean(self.llm_engine.cache_service_metrics.time_load), np.std(self.llm_engine.cache_service_metrics.time_load), np.median(self.llm_engine.cache_service_metrics.time_load),
+                np.mean(self.llm_engine.cache_service_metrics.normalized_time_load), np.std(self.llm_engine.cache_service_metrics.normalized_time_load), np.median(self.llm_engine.cache_service_metrics.normalized_time_load),
+                np.mean(self.llm_engine.cache_service_metrics.time_reshape), np.std(self.llm_engine.cache_service_metrics.time_reshape), np.median(self.llm_engine.cache_service_metrics.time_reshape),
+                np.mean(self.llm_engine.cache_service_metrics.normalized_time_reshape), np.std(self.llm_engine.cache_service_metrics.normalized_time_reshape), np.median(self.llm_engine.cache_service_metrics.normalized_time_reshape),
+                np.mean(self.llm_engine.cache_service_metrics.time_unload), np.std(self.llm_engine.cache_service_metrics.time_unload), np.median(self.llm_engine.cache_service_metrics.time_unload),
+                np.mean(self.llm_engine.cache_service_metrics.normalized_time_unload), np.std(self.llm_engine.cache_service_metrics.normalized_time_unload), np.median(self.llm_engine.cache_service_metrics.normalized_time_unload),
+                np.mean(self.llm_engine.cache_service_metrics.time_update), np.std(self.llm_engine.cache_service_metrics.time_update), np.median(self.llm_engine.cache_service_metrics.time_update),
+                np.mean(self.llm_engine.cache_service_metrics.normalized_time_update), np.std(self.llm_engine.cache_service_metrics.normalized_time_update), np.median(self.llm_engine.cache_service_metrics.normalized_time_update),
+                np.mean(times_to_first_token), np.std(times_to_first_token), np.median(times_to_first_token),
+                np.mean(normalized_times_to_first_token), np.std(normalized_times_to_first_token), np.median(normalized_times_to_first_token),
+            )
+
         # Restore original behavior
         self.llm_engine.step_return_finished_only = False
 
         if use_tqdm:
             pbar.close()
         # Sort the outputs by request ID.
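The summary line above reports avg/std/median for each timing list with `np.mean`, `np.std`, and `np.median`. One caveat: when a list is empty (for example, no prefill was ever served from the cache), NumPy returns `nan` and emits a runtime warning. A small guard of the following shape would sidestep that; the `summarize` helper is a sketch, not part of the PR:

```python
import numpy as np

def summarize(name: str, samples: list) -> str:
    """Format avg/std/median for one timing list, tolerating empty input."""
    if len(samples) == 0:
        return f"{name}: no samples"
    return (f"{name}: avg {np.mean(samples):.4f} "
            f"std {np.std(samples):.4f} median {np.median(samples):.4f}")
```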
diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py
index ad84422ee2129..650b3af940dee 100644
--- a/vllm/executor/distributed_gpu_executor.py
+++ b/vllm/executor/distributed_gpu_executor.py
@@ -8,6 +8,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest
+from vllm.worker.vineyard_llm_cache import CacheServiceMetrics
 
 logger = init_logger(__name__)
 
@@ -47,7 +48,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         return num_gpu_blocks, num_cpu_blocks
 
     def initialize_cache(self, num_gpu_blocks: int,
-                         num_cpu_blocks: int) -> None:
+                         num_cpu_blocks: int,
+                         cache_service_metrics: Optional[CacheServiceMetrics] = None) -> None:
         """Initialize the KV cache in all workers.
         """
 
@@ -62,7 +64,9 @@ def initialize_cache(self, num_gpu_blocks: int,
         self._run_workers("initialize_cache",
                           num_gpu_blocks=num_gpu_blocks,
-                          num_cpu_blocks=num_cpu_blocks)
+                          num_cpu_blocks=num_cpu_blocks,
+                          cache_service_metrics=cache_service_metrics)
 
     def execute_model(
         self,
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index 2185c9cf6cead..ea7a1c62658de 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -9,6 +9,7 @@
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         make_async)
 from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
+from vllm.worker.vineyard_llm_cache import CacheServiceMetrics
 
 logger = init_logger(__name__)
 
@@ -113,16 +114,21 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         """
         return self.driver_worker.determine_num_available_blocks()
 
-    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
+    def initialize_cache(self,
+                         num_gpu_blocks: int,
+                         num_cpu_blocks: int,
+                         cache_service_metrics: Optional[CacheServiceMetrics] = None) -> None:
         """Initialize the KV cache by invoking the underlying worker.
         """
         # NOTE: This is logged in the executor because there can be >1 worker
         # with other executors. We could log in the engine level, but work
         # remains to abstract away the device for non-GPU configurations.
-        logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
-                    num_cpu_blocks)
+        logger.info("# GPU blocks: %d, # CPU blocks: %d, cache_service_metrics: %s",
+                    num_gpu_blocks, num_cpu_blocks, cache_service_metrics)
 
-        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
+        self.driver_worker.initialize_cache(
+            num_gpu_blocks,
+            num_cpu_blocks,
+            cache_service_metrics=cache_service_metrics)
 
     def execute_model(
         self, execute_model_req: ExecuteModelRequest
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index b68faebed4a5f..55924ed0bd69f 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -55,6 +55,7 @@
     _add_sampling_metadata_broadcastable_dict,
     _init_attn_metadata_from_tensor_dict,
     _init_sampling_metadata_from_tensor_dict)
+from vllm.worker.vineyard_llm_cache import CacheServiceMetrics
 
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
@@ -976,6 +977,7 @@ def __init__(
         # Multi-modal data support
         self.input_registry = input_registry
         self.mm_registry = mm_registry
+        self.cache_service_metrics = None
         self.multi_modal_input_mapper = mm_registry \
             .create_input_mapper(model_config)
         self.mm_registry.init_mm_limits_per_prompt(self.model_config)
@@ -997,7 +999,7 @@ def __init__(
         # to ensure the tensor model parallel group is initialized.
         self.vineyard_llm_cache = None
 
-    def _init_vineyard_cache(self):
+    def _init_vineyard_cache(self, metrics: Optional[CacheServiceMetrics] = None):
         if envs.VLLM_USE_VINEYARD_CACHE:
             if not self.scheduler_config.chunked_prefill_enabled:
                 raise Exception("Vineyard LLM cache is not enabled, requires chunked prefill")
@@ -1011,6 +1013,7 @@ def _init_vineyard_cache(self):
                 kv_cache_dtype=self.kv_cache_dtype,
                 torch_dtype=get_kv_cache_torch_dtype(self.kv_cache_dtype,
                                                      self.model_config.dtype),
+                metrics=metrics,
             )
             logger.info("Using Vineyard LLM cache")
         else:
@@ -1020,6 +1023,7 @@ def _init_vineyard_cache(self):
         self.inter_data_cache: Dict[int, PyObjectCache] = {}
         self.sampling_metadata_cache: SamplingMetadataCache = \
             SamplingMetadataCache()
+        logger.info(f"Initialized CacheServiceMetrics: {metrics}")
 
     def load_model(self) -> None:
         logger.info("Starting to load model %s...", self.model_config.model)
@@ -1096,7 +1100,7 @@ def load_model(self) -> None:
                 self.model,
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                 backend="eager")
-        self._init_vineyard_cache()
+        self._init_vineyard_cache(self.cache_service_metrics)
 
     def save_sharded_state(
         self,
@@ -1128,6 +1132,10 @@ def save_tensorized_model(
     def get_max_block_per_batch(self) -> int:
         block_size = self.block_size
         return (self.max_seq_len_to_capture + block_size - 1) // block_size
+
+    def set_cache_service_metrics(self, metrics: CacheServiceMetrics) -> None:
+        if self.vineyard_llm_cache is not None:
+            self.vineyard_llm_cache.metrics = metrics
 
     def _prepare_model_input_tensors(
         self,
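In the model runner above, `self.cache_service_metrics` starts as `None`, `load_model()` hands it to `_init_vineyard_cache()`, and the real metrics object only arrives later when `Worker.initialize_cache()` (further down in this diff) calls `set_cache_service_metrics()`. In other words, the cache is constructed first and the shared metrics sink is injected afterwards. A self-contained sketch of that ordering, with hypothetical stand-in classes:

```python
class Metrics:
    """Stand-in for CacheServiceMetrics."""
    def __init__(self) -> None:
        self.time_query: list = []

class Cache:
    """Stand-in for VineyardLLMCache: may be built before metrics exist."""
    def __init__(self, metrics=None) -> None:
        self.metrics = metrics

class Runner:
    """Stand-in for ModelRunner."""
    def __init__(self) -> None:
        self.cache_service_metrics = None
        self.cache = None

    def load_model(self) -> None:
        # mirrors _init_vineyard_cache(self.cache_service_metrics); metrics may still be None here
        self.cache = Cache(metrics=self.cache_service_metrics)

    def set_cache_service_metrics(self, metrics) -> None:
        if self.cache is not None:
            self.cache.metrics = metrics

runner = Runner()
runner.load_model()                           # cache built with metrics=None
runner.set_cache_service_metrics(Metrics())   # injected later, as Worker.initialize_cache does
assert runner.cache.metrics is not None
```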
diff --git a/vllm/worker/vineyard_llm_cache.py b/vllm/worker/vineyard_llm_cache.py
index f43780efee8cd..009ed74bc3e06 100644
--- a/vllm/worker/vineyard_llm_cache.py
+++ b/vllm/worker/vineyard_llm_cache.py
@@ -1,5 +1,6 @@
 import logging
 import time
+import numpy as np
 from typing import Dict, List, NamedTuple, Optional, Set, Tuple
 
 import torch
@@ -24,7 +25,25 @@
 
 logger = init_logger(__name__)
 
-
+class CacheServiceMetrics:
+    hit_tokens: int = 0       # Total number of tokens hit.
+    total_tokens: int = 0     # Total number of tokens requested.
+    hit_blocks: int = 0       # Total number of blocks hit.
+    total_blocks: int = 0     # Total number of blocks requested.
+    counter: int = 0          # Total number of measurements collected.
+    time_query: list = []     # Time spent querying the cache service.
+    time_load: list = []      # Time spent loading fetched KV into device memory.
+    time_reshape: list = []   # Time spent reshaping tensors into the flash attention KV format.
+    time_unload: list = []    # Time spent moving computed KV out of device memory.
+    time_update: list = []    # Time spent writing computed KV back to the cache service.
+    normalized_time_query: list = []    # time_query normalized by the number of tokens.
+    normalized_time_load: list = []     # time_load normalized by the number of tokens.
+    normalized_time_reshape: list = []  # time_reshape normalized by the number of tokens.
+    normalized_time_unload: list = []   # time_unload normalized by the number of tokens.
+    normalized_time_update: list = []   # time_update normalized by the number of tokens.
+
+
 class VineyardLLMCache:
     def __init__(
         self,
@@ -34,6 +53,7 @@ def __init__(
         layer: int = 2,
         kv_cache_dtype: str = None,
         torch_dtype: torch.dtype = torch.bfloat16,
+        metrics: Optional[CacheServiceMetrics] = None,
     ):
         self._init_vineyard_logger()
@@ -67,6 +87,8 @@ def __init__(
                 VineyardKVTensor(k_tensor.data_ptr(), k_tensor.numel() * k_tensor.element_size()),
                 VineyardKVTensor(v_tensor.data_ptr(), v_tensor.numel() * v_tensor.element_size()),
             ))
+        self.metrics = metrics
+        logger.info(f"VineyardLLMCache init: metrics={metrics}")
 
     def _init_vineyard_logger(self):
         import vineyard
@@ -83,6 +105,7 @@ def from_envs(
         parallel_config: ParallelConfig,
         kv_cache_dtype: str,
         torch_dtype: torch.dtype = torch.bfloat16,
+        metrics: Optional[CacheServiceMetrics] = None,
     ) -> Optional["VineyardLLMCache"]:
         if VineyardKVCache is None:
             logger.warn("VineyardKVCache module is not available")
@@ -95,7 +118,7 @@ def from_envs(
         head_size = model_config.get_head_size()
         num_kv_heads = model_config.get_num_kv_heads(parallel_config)
         num_layers = model_config.get_num_layers(parallel_config)
-
+        logger.info(f"VineyardLLMCache from_envs: metrics={metrics}")
         return VineyardLLMCache(
             head_size=head_size,
             num_kv_heads=num_kv_heads,
@@ -103,6 +126,7 @@ def from_envs(
             layer=num_layers,
             kv_cache_dtype=kv_cache_dtype,
             torch_dtype=torch_dtype,
+            metrics=metrics,
         )
 
     def prefetch_seq_kv_caches(
@@ -129,7 +153,6 @@ def prefetch_seq_kv_caches(
         # alignment `context_len` to `self.chunk_size`
         query_context_len = context_len - context_len % self.chunk_size
         query_token_size = context_len + token_chunk_size - query_context_len
-
         query_prefix = tokens[:query_context_len]
         query_tokens = tokens[query_context_len:query_context_len + query_token_size]
         query_args = [
@@ -156,13 +179,15 @@ def prefetch_seq_kv_caches(
             query_tokens
         ) = query_args
-
+        start_time = time.time()
         matched = self.cache.query(
             prefix=query_prefix,
             tokens=query_tokens,
             kv_cache_list=self.tensors[:query_token_size],
         )
-
+        duration = time.time() - start_time
+        self.metrics.time_query.append(duration)
+        self.metrics.normalized_time_query.append(duration / len(tokens))
         # synchronized across tensor parallel ranks
         matched_tensor = torch.tensor([matched], dtype=torch.long, device='cuda')
         # torch.distributed.all_reduce(matched_tensor, op=torch.distributed.ReduceOp.MIN,
@@ -176,7 +201,6 @@ def prefetch_seq_kv_caches(
         matched = min(matched, token_chunk_size - 1)
         if matched <= 0:
             return seq_id, 0
-
         if seq_group_metadata is not None:
             block_table = seq_group_metadata.block_tables[seq_id]
             slot_mapping = []
@@ -193,9 +217,25 @@ def prefetch_seq_kv_caches(
             # torch.distributed.broadcast(slot_mapping, src=0,
             #                             group=get_tensor_model_parallel_group())
-
+        self.metrics.hit_tokens += matched
+        self.metrics.total_tokens += query_token_size
+        self.metrics.hit_blocks += (matched // block_size)
+        self.metrics.total_blocks += -(-query_token_size // block_size)  # ceil(query_token_size / block_size)
+        self.metrics.counter += 1
         # save to GPU kv cache
+        copy_start = torch.cuda.Event(enable_timing=True)
+        copy_end = torch.cuda.Event(enable_timing=True)
+        copy_start.record()
         buffer = self.buffer[:, :, offset:offset+matched].cuda()
+        copy_end.record()
+        torch.cuda.synchronize()
+        duration = copy_start.elapsed_time(copy_end) / 1000.0
+        self.metrics.time_load.append(duration)
+        self.metrics.normalized_time_load.append(0 if matched == 0 else duration / matched)
+
+        reshape_start = torch.cuda.Event(enable_timing=True)
+        reshape_end = torch.cuda.Event(enable_timing=True)
+        reshape_start.record()
         for j in range(self.layer):
             # use `reshape_and_cache_flash` rather than `copy_` as
             # the target kv cache slots is not contingous.
@@ -209,7 +249,12 @@ def prefetch_seq_kv_caches(
                 1.0,
                 1.0
             )
-
+        reshape_end.record()
+        torch.cuda.synchronize()
+        duration = reshape_start.elapsed_time(reshape_end) / 1000.0
+        self.metrics.time_reshape.append(duration)
+        self.metrics.normalized_time_reshape.append(0 if matched == 0 else duration / matched)
+
         # update the seq_group_metadata's and seq's metadata
         if seq_group_metadata is not None:
             seq_data.update_num_computed_tokens(matched)
@@ -243,7 +288,6 @@ def prefetch_kv_caches(
             #                             group=get_tensor_model_parallel_group())
         prefill_requests = [None] * num_prefill_requests[0]
         num_prefill_requests = num_prefill_requests[0]
-
         matched = {}
         for seq_group_meta in prefill_requests:
             seq_id, seq_matched = self.prefetch_seq_kv_caches(
@@ -297,14 +341,12 @@ def update_seq_kv_caches(
             update_prefix,
             update_tokens,
         ) = update_args
-
         if update_token_size <= 0:
             # restore the seq_group_metadata's and seq's metadata
             if seq_group_metadata is not None:
                 seq_data.update_num_computed_tokens(-matched[seq_id])
                 seq_group_metadata.token_chunk_size += matched[seq_id]
             return seq_id, 0
-
         if seq_group_metadata is not None:
             block_table = seq_group_metadata.block_tables[seq_id]
             slot_mapping = []
@@ -322,16 +364,28 @@ def update_seq_kv_caches(
             #                             group=get_tensor_model_parallel_group())
 
         # fetch from GPU kv cache
+        start_unload = torch.cuda.Event(enable_timing=True)
+        end_unload = torch.cuda.Event(enable_timing=True)
+        start_unload.record()
         for j in range(self.layer):
             self.buffer[:, j, :update_token_size].copy_(
                 kv_caches[j][:, slot_mapping // block_size, slot_mapping % block_size])
-
+        end_unload.record()
+        torch.cuda.synchronize()
+        duration = start_unload.elapsed_time(end_unload) / 1000.0
+        self.metrics.time_unload.append(duration)
+        self.metrics.normalized_time_unload.append(0 if update_token_size == 0 else duration / update_token_size)
+
+        start_time = time.time()
         # updates into vineyard
         updated = self.cache.update(
             prefix=update_prefix,
             tokens=update_tokens,
             kv_cache_list=self.tensors[:update_token_size],
         )
+        duration = time.time() - start_time
+        self.metrics.time_update.append(duration)
+        self.metrics.normalized_time_update.append(0 if update_token_size == 0 else duration / update_token_size)
 
         # restore the seq_group_metadata's and seq's metadata
         if seq_group_metadata is not None:
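The load/reshape/unload timings above use CUDA events rather than `time.time()`, since the device copies and `reshape_and_cache_flash` calls run asynchronously on the GPU stream. Reduced to a self-contained sketch (function and variable names are illustrative only; this requires a CUDA device):

```python
import torch

def timed_to_gpu(host_tensor: torch.Tensor):
    """Copy a tensor to the GPU and return (result, seconds), timed with CUDA events."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    device_tensor = host_tensor.cuda()  # stands in for the buffer load above
    end.record()
    torch.cuda.synchronize()            # wait until the end event has completed
    return device_tensor, start.elapsed_time(end) / 1000.0  # elapsed_time is in milliseconds
```

The ordering matters: the end event must be recorded and the device synchronized before `elapsed_time()` is read, otherwise the events may not yet be complete.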
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 52092dc2dc291..298795360843c 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -28,6 +28,7 @@
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
 from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
 from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput
+from vllm.worker.vineyard_llm_cache import CacheServiceMetrics
 
 logger = init_logger(__name__)
 
@@ -250,7 +251,8 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         return num_gpu_blocks, num_cpu_blocks
 
     def initialize_cache(self, num_gpu_blocks: int,
-                         num_cpu_blocks: int) -> None:
+                         num_cpu_blocks: int,
+                         cache_service_metrics: Optional[CacheServiceMetrics] = None) -> None:
         """Allocate GPU and CPU KV cache with the specified number of blocks.
 
         This also warms up the model, which may record CUDA graphs.
         """
@@ -261,7 +263,8 @@ def initialize_cache(self, num_gpu_blocks: int,
         self.cache_config.num_gpu_blocks = num_gpu_blocks
         self.cache_config.num_cpu_blocks = num_cpu_blocks
-
+        logger.info(f"Worker.initialize_cache: cache_service_metrics={cache_service_metrics}")
+        self.model_runner.set_cache_service_metrics(cache_service_metrics)
         self._init_cache_engine()
         self._warm_up_model()
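For reference, a minimal way to exercise this plumbing end to end and read the collected metrics back (sketch only; it assumes a running Vineyard cache service with `VLLM_USE_VINEYARD_CACHE` set, and the model name is just an example):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m",
          enable_chunked_prefill=True)  # chunked prefill is required by the Vineyard path
llm.generate(["San Francisco is a"], SamplingParams(temperature=0, max_tokens=16))

m = llm.llm_engine.cache_service_metrics
print("hit tokens:", m.hit_tokens, "of", m.total_tokens)
print("hit blocks:", m.hit_blocks, "of", m.total_blocks)
print("query timings collected:", len(m.time_query))
```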