diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1339ed97cf138..5b08436b1db35 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -297,7 +297,7 @@ def __init__( ) self.log_stats = log_stats self.step_return_finished_only = step_return_finished_only - self.cache_service_metrics = CacheServiceMetrics + self.cache_service_metrics = CacheServiceMetrics() if not self.model_config.skip_tokenizer_init: self.tokenizer = self._init_tokenizer() @@ -1815,7 +1815,14 @@ def _get_stats(self, best_of_requests: List[int] = [] n_requests: List[int] = [] finished_reason_requests: List[str] = [] - + + # Cache Service Metrics + cache_service_tokens_hit_rate: float + cache_service_blocks_hit_rate: float + cache_service_time_async_update_queue: List[int] = [] + cache_service_time_async_update_exec: List[int] = [] + cache_service_counter_async_update_updated: List[int] = [] + # NOTE: This loop assumes prefill seq_groups are before # decode seq_groups in scheduled_seq_groups. if scheduler_outputs is not None: @@ -1903,6 +1910,32 @@ def _get_stats(self, spec_decode_metrics = model_output[0].spec_decode_worker_metrics else: spec_decode_metrics = None + + if self.cache_service_metrics is not None: + cache_service_hit_tokens = self.cache_service_metrics.hit_tokens + cache_service_total_tokens = self.cache_service_metrics.total_tokens + cache_service_hit_blocks = self.cache_service_metrics.hit_blocks + cache_service_total_blocks = self.cache_service_metrics.total_blocks + cache_service_tokens_hit_rate = self.cache_service_metrics.get_tokens_hit_rate() + cache_service_blocks_hit_rate = self.cache_service_metrics.get_blocks_hit_rate() + cache_service_err_query = self.cache_service_metrics.err_query + cache_service_err_async_update_task_queue_full = self.cache_service_metrics.err_async_update_task_queue_full + cache_service_err_update = self.cache_service_metrics.err_update + + cache_service_time_query = self.cache_service_metrics.time_query + cache_service_time_load = self.cache_service_metrics.time_load + cache_service_time_reshape = self.cache_service_metrics.time_reshape + cache_service_time_unload = self.cache_service_metrics.time_unload + cache_service_time_update = self.cache_service_metrics.time_update + cache_service_time_async_update_queue, cache_service_time_async_update_exec, cache_service_counter_async_update_updated = self.cache_service_metrics.get_async_metrics() + + self.cache_service_metrics.time_query = [] + self.cache_service_metrics.time_load = [] + self.cache_service_metrics.time_reshape = [] + self.cache_service_metrics.time_unload = [] + self.cache_service_metrics.time_update = [] + self.cache_service_metrics.reset_async_metrics() + return Stats( now=now, @@ -1935,6 +1968,25 @@ def _get_stats(self, best_of_requests=best_of_requests, n_requests=n_requests, finished_reason_requests=finished_reason_requests, + + # Cache Service + cache_service_hit_tokens = cache_service_hit_tokens, + cache_service_total_tokens = cache_service_total_tokens, + cache_service_hit_blocks = cache_service_hit_blocks, + cache_service_total_blocks = cache_service_total_blocks, + cache_service_tokens_hit_rate = cache_service_tokens_hit_rate, + cache_service_blocks_hit_rate = cache_service_blocks_hit_rate, + cache_service_err_query = cache_service_err_query, + cache_service_err_async_update_task_queue_full = cache_service_err_async_update_task_queue_full, + cache_service_err_update = cache_service_err_update, + cache_service_time_query = cache_service_time_query, + 
cache_service_time_load = cache_service_time_load, + cache_service_time_reshape = cache_service_time_reshape, + cache_service_time_unload = cache_service_time_unload, + cache_service_time_update = cache_service_time_update, + cache_service_time_async_update_queue = cache_service_time_async_update_queue, + cache_service_time_async_update_exec = cache_service_time_async_update_exec, + cache_service_counter_async_update_updated = cache_service_counter_async_update_updated, ) def add_lora(self, lora_request: LoRARequest) -> bool: diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index da28355b3694f..1d8a2d5e01163 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -190,6 +190,132 @@ def __init__(self, labelnames: List[str], max_model_len: int): labelnames=labelnames, multiprocess_mode="sum", ) + + self.gauge_cache_service_tokens_hit_rate = self._gauge_cls( + name="vllm:cache_service_tokens_hit_rate", + documentation="External cache service tokens hit rate.", + labelnames=labelnames, + multiprocess_mode="all") + + self.gauge_cache_service_blocks_hit_rate = self._gauge_cls( + name="vllm:cache_service_blocks_hit_rate", + documentation="External cache service blocks hit rate.", + labelnames=labelnames, + multiprocess_mode="all") + + self.gauge_cache_service_hit_tokens = self._gauge_cls( + name="vllm:cache_service_hit_tokens", + documentation="External cache service hit tokens.", + labelnames=labelnames, + multiprocess_mode="all") + + self.gauge_cache_service_total_tokens = self._gauge_cls( + name="vllm:cache_service_total_tokens", + documentation="External cache service total tokens.", + labelnames=labelnames, + multiprocess_mode="all") + + self.gauge_cache_service_hit_blocks = self._gauge_cls( + name="vllm:cache_service_hit_blocks", + documentation="External cache service hit blocks.", + labelnames=labelnames, + multiprocess_mode="all") + + self.gauge_cache_service_total_blocks = self._gauge_cls( + name="vllm:cache_service_total_blocks", + documentation="External cache service total blocks.", + labelnames=labelnames, + multiprocess_mode="all") + + self.gauge_cache_service_err_query = self._gauge_cls( + name="vllm:cache_service_err_query", + documentation="External cache service query errors.", + labelnames=labelnames, + multiprocess_mode="all") + + self.gauge_cache_service_err_async_update_task_queue_full = self._gauge_cls( + name="vllm:cache_service_err_async_update_task_queue_full", + documentation="External cache service async update task queue full errors.", + labelnames=labelnames, + multiprocess_mode="all") + + self.gauge_cache_service_err_update = self._gauge_cls( + name="vllm:cache_service_err_update", + documentation="External cache service update errors.", + labelnames=labelnames, + multiprocess_mode="all") + + self.histogram_cache_service_time_query_seconds = self._histogram_cls( + name="vllm:cache_service_time_query_seconds", + documentation="Histogram of cache service time query in seconds.", + labelnames=labelnames, + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5 + ]) + + self.histogram_cache_service_time_load_seconds = self._histogram_cls( + name="vllm:cache_service_time_load_seconds", + documentation="Histogram of cache service time load.", + labelnames=labelnames, + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5 + ]) + + self.histogram_cache_service_time_reshape_seconds = self._histogram_cls( + name="vllm:cache_service_time_reshape_seconds", + documentation="Histogram of 
cache service time update.", + labelnames=labelnames, + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5 + ]) + + self.histogram_cache_service_time_unload_seconds = self._histogram_cls( + name="vllm:cache_service_time_unload_seconds", + documentation="Histogram of cache service time unload.", + labelnames=labelnames, + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5 + ]) + + self.histogram_cache_service_time_update_seconds = self._histogram_cls( + name="vllm:cache_service_time_update_seconds", + documentation="Histogram of cache service time update.", + labelnames=labelnames, + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5 + ]) + + self.histogram_cache_service_time_async_update_queue_seconds = self._histogram_cls( + name="vllm:cache_service_time_async_update_queue_seconds", + documentation="Histogram of cache service async update time in queue.", + labelnames=labelnames, + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5 + ]) + + self.histogram_cache_service_time_async_update_exec_seconds = self._histogram_cls( + name="vllm:cache_service_time_async_update_exec_seconds", + documentation="Histogram of cache service async update time in execution.", + labelnames=labelnames, + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5 + ]) + + self.histogram_cache_service_counter_async_update_updated_seconds = self._histogram_cls( + name="vllm:cache_service_counter_async_update_updated_seconds", + documentation="Histogram of cache service async update time in update.", + labelnames=labelnames, + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5 + ]) # end-metrics-definitions @@ -375,12 +501,10 @@ def log(self, stats: Stats) -> None: self._format_spec_decode_metrics_str( self.spec_decode_metrics)) - if self.external_cache_service_metrics is not None: - logger.info( - "Cache service hit rate: by tokens: %.2f%%, by blocks: %.2f%%", - 0 if self.external_cache_service_metrics.total_tokens == 0 else self.external_cache_service_metrics.hit_tokens/self.external_cache_service_metrics.total_tokens * 100, - 0 if self.external_cache_service_metrics.total_blocks == 0 else self.external_cache_service_metrics.hit_blocks/self.external_cache_service_metrics.total_blocks * 100, - ) + logger.info( + "Cache service hit rate: by tokens: %.2f%%, by blocks: %.2f%%", + stats.cache_service_tokens_hit_rate, stats.cache_service_blocks_hit_rate + ) # Reset tracked stats for next interval. 
self.num_prompt_tokens = [] self.num_generation_tokens = [] @@ -482,6 +606,42 @@ def _log_prometheus(self, stats: Stats) -> None: self._log_histogram(self.metrics.histogram_n_request, stats.n_requests) self._log_histogram(self.metrics.histogram_best_of_request, stats.best_of_requests) + + # Cache Service + self._log_gauge(self.metrics.gauge_cache_service_hit_tokens, + stats.cache_service_hit_tokens) + self._log_gauge(self.metrics.gauge_cache_service_total_tokens, + stats.cache_service_total_tokens) + self._log_gauge(self.metrics.gauge_cache_service_hit_blocks, + stats.cache_service_hit_blocks) + self._log_gauge(self.metrics.gauge_cache_service_total_blocks, + stats.cache_service_total_blocks) + self._log_gauge(self.metrics.gauge_cache_service_tokens_hit_rate, + stats.cache_service_tokens_hit_rate) + self._log_gauge(self.metrics.gauge_cache_service_blocks_hit_rate, + stats.cache_service_blocks_hit_rate) + self._log_gauge(self.metrics.gauge_cache_service_err_query, + stats.cache_service_err_query) + self._log_gauge(self.metrics.gauge_cache_service_err_async_update_task_queue_full, + stats.cache_service_err_async_update_task_queue_full) + self._log_gauge(self.metrics.gauge_cache_service_err_update, + stats.cache_service_err_update) + self._log_histogram(self.metrics.histogram_cache_service_time_query_seconds, + stats.cache_service_time_query) + self._log_histogram(self.metrics.histogram_cache_service_time_load_seconds, + stats.cache_service_time_load) + self._log_histogram(self.metrics.histogram_cache_service_time_reshape_seconds, + stats.cache_service_time_reshape) + self._log_histogram(self.metrics.histogram_cache_service_time_unload_seconds, + stats.cache_service_time_unload) + self._log_histogram(self.metrics.histogram_cache_service_time_update_seconds, + stats.cache_service_time_update) + self._log_histogram(self.metrics.histogram_cache_service_time_async_update_queue_seconds, + stats.cache_service_time_async_update_queue) + self._log_histogram(self.metrics.histogram_cache_service_time_async_update_exec_seconds, + stats.cache_service_time_async_update_exec) + self._log_histogram(self.metrics.histogram_cache_service_counter_async_update_updated_seconds, + stats.cache_service_counter_async_update_updated) def _log_prometheus_interval(self, prompt_throughput: float, generation_throughput: float) -> None: diff --git a/vllm/engine/metrics_types.py b/vllm/engine/metrics_types.py index 57647f1fa7e5f..61d2d61dcada7 100644 --- a/vllm/engine/metrics_types.py +++ b/vllm/engine/metrics_types.py @@ -53,7 +53,27 @@ class Stats: n_requests: List[int] finished_reason_requests: List[str] - spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None + # Cache Service + cache_service_hit_tokens: int + cache_service_total_tokens: int + cache_service_hit_blocks: int + cache_service_total_blocks: int + cache_service_tokens_hit_rate: float + cache_service_blocks_hit_rate: float + cache_service_err_query: int + cache_service_err_async_update_task_queue_full: int + cache_service_err_update: int + cache_service_time_query: List[float] + cache_service_time_load: List[float] + cache_service_time_reshape: List[float] + cache_service_time_unload: List[float] + cache_service_time_update: List[float] + cache_service_time_async_update_queue: List[float] + cache_service_time_async_update_exec: List[float] + cache_service_counter_async_update_updated: List[float] + + + spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None class SupportsMetricsInfo(Protocol): @@ -71,7 +91,6 @@ def __init__(self, 
local_interval: float) -> None: self.num_generation_tokens: List[int] = [] self.last_local_log = time.time() self.local_interval = local_interval - self.external_cache_service_metrics = CacheServiceMetrics self.spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None @abstractmethod diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 8820bf3896a9a..29fe100c22278 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -742,27 +742,19 @@ def _run_engine( pbar.update(1) if self.llm_engine.cache_service_metrics is not None: logger.info( - "Cache service hit rate: by tokens: %.2f%%, by blocks: %.2f%%, total tokens hit %d, number of measurement collected %d", + "Cache service hit rate: by tokens: %.2f%%, by blocks: %.2f%%, total tokens hit %d", 0 if self.llm_engine.cache_service_metrics.total_tokens == 0 else self.llm_engine.cache_service_metrics.hit_tokens/self.llm_engine.cache_service_metrics.total_tokens * 100, 0 if self.llm_engine.cache_service_metrics.total_blocks == 0 else self.llm_engine.cache_service_metrics.hit_blocks/self.llm_engine.cache_service_metrics.total_blocks * 100, self.llm_engine.cache_service_metrics.hit_tokens, - self.llm_engine.cache_service_metrics.counter, ) logger.info( - "Cache service time query avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time load avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time reshape avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time unload avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time update avg %.4f std %.4f median %.4f, normalized mean %.4f std %.4f median %.4f, time to first token avg %.4f std %.4f median %.4f, normalized time to first token %.4f std %.4f median %.4f", + "Cache service time query avg %.4f std %.4f median %.4f, time load avg %.4f std %.4f median %.4f, time reshape avg %.4f std %.4f median %.4f, time unload avg %.4f std %.4f median %.4f, time update avg %.4f std %.4f median %.4f, time to first token avg %.4f std %.4f median %.4f", np.mean(self.llm_engine.cache_service_metrics.time_query), np.std(self.llm_engine.cache_service_metrics.time_query), np.median(self.llm_engine.cache_service_metrics.time_query), - np.mean(self.llm_engine.cache_service_metrics.normalized_time_query), np.std(self.llm_engine.cache_service_metrics.normalized_time_query), np.median(self.llm_engine.cache_service_metrics.normalized_time_query), np.mean(self.llm_engine.cache_service_metrics.time_load), np.std(self.llm_engine.cache_service_metrics.time_load), np.median(self.llm_engine.cache_service_metrics.time_load), - np.mean(self.llm_engine.cache_service_metrics.normalized_time_load), np.std(self.llm_engine.cache_service_metrics.normalized_time_load), np.median(self.llm_engine.cache_service_metrics.normalized_time_load), np.mean(self.llm_engine.cache_service_metrics.time_reshape), np.std(self.llm_engine.cache_service_metrics.time_reshape), np.median(self.llm_engine.cache_service_metrics.time_reshape), - np.mean(self.llm_engine.cache_service_metrics.normalized_time_reshape), np.std(self.llm_engine.cache_service_metrics.normalized_time_reshape), np.median(self.llm_engine.cache_service_metrics.normalized_time_reshape), np.mean(self.llm_engine.cache_service_metrics.time_unload), np.std(self.llm_engine.cache_service_metrics.time_unload), np.median(self.llm_engine.cache_service_metrics.time_unload), - np.mean(self.llm_engine.cache_service_metrics.normalized_time_unload), 
np.std(self.llm_engine.cache_service_metrics.normalized_time_unload), np.median(self.llm_engine.cache_service_metrics.normalized_time_unload),
                 np.mean(self.llm_engine.cache_service_metrics.time_update), np.std(self.llm_engine.cache_service_metrics.time_update), np.median(self.llm_engine.cache_service_metrics.time_update),
-                np.mean(self.llm_engine.cache_service_metrics.normalized_time_update), np.std(self.llm_engine.cache_service_metrics.normalized_time_update), np.median(self.llm_engine.cache_service_metrics.normalized_time_update),
                 np.mean(times_to_first_token), np.std(times_to_first_token), np.median(times_to_first_token),
-                np.mean(normalized_times_to_first_token), np.std(normalized_times_to_first_token), np.median(normalized_times_to_first_token),
-
             )
             with self.llm_engine.cache_service_metrics.lock:
diff --git a/vllm/envs.py b/vllm/envs.py
index 4aba1fb63c877..d31f10228d7f5 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -69,6 +69,7 @@
     VINEYARD_CACHE_ASYNC_UPDATE_CPU_MEM_UTIL: float = 0.2
     VINEYARD_CACHE_MIN_INFLIGHT_TASKS: int = 1
     VINEYARD_CACHE_MAX_INFLIGHT_TASKS: int = 32
+    VINEYARD_CACHE_METRICS_ENABLED: bool = False
 
 
 def get_default_cache_root():
@@ -458,6 +459,12 @@ def get_default_config_root():
     # Max number of inflight async tasks for vineyard cache
     "VINEYARD_CACHE_MAX_INFLIGHT_TASKS":
     lambda: int(os.getenv("VINEYARD_CACHE_MAX_INFLIGHT_TASKS", "32")),
+
+    # Whether to enable metrics collection for the vineyard cache
+    "VINEYARD_CACHE_METRICS_ENABLED": lambda: (
+        os.environ.get("VINEYARD_CACHE_METRICS_ENABLED", "0").strip().lower()
+        in ("1", "true")
+    ),
 }
 
 # end-env-vars-definition
diff --git a/vllm/worker/vineyard_llm_cache.py b/vllm/worker/vineyard_llm_cache.py
index 29649e5b6c0f4..8a3904885edb6 100644
--- a/vllm/worker/vineyard_llm_cache.py
+++ b/vllm/worker/vineyard_llm_cache.py
@@ -4,6 +4,7 @@
 import threading
 from functools import partial
 from queue import Queue, Full
+from collections import deque
 from typing import Dict, List, NamedTuple, Optional, Set, Tuple
 
 import torch
@@ -30,33 +31,79 @@
 logger = init_logger(__name__)
 
-class CacheServiceMetrics:
-    hit_tokens: int = 0 # Total number of tokens hit.
-    total_tokens: int = 0 # Total number of tokens requested.
-    hit_blocks: int = 0 # Total number of blocks hit.
-    total_blocks: int = 0 # Total number of blocks requested.
-    counter: int = 0 # Total number of measurements.
-    time_query: list = [] # Times used query cache from cache service.
-    time_load: list = [] # Times used load fetched cache to device memory.
-    time_reshape: list = [] # Times used reshaping tensors for flash attention KV format.
-    time_unload: list = [] # Times used move computed KV from device memory.
-    time_update: list = [] # Times used update computed KV to cache service.
-    normalized_time_query: list = [] # Times used query cache from cache service normalized by number of tokens.
-    normalized_time_load: list = [] # Times used load fetched cache to device memory normalized by number of tokens.
-    normalized_time_reshape: list = [] # Times used reshaping tensors for flash attention KV format normalized by number of tokens.
-    normalized_time_unload: list = [] # Times used move computed KV from device memory normalized by number of tokens.
-    normalized_time_update: list = [] # Times used update computed KV to cache service normalized by number of tokens.
-
-    err_query: int = 0 # Number of query errors.
-    err_async_update_task_queue_full: int = 0 # Number of Full exceptions when enqueuing async update tasks
-
-    lock: threading.Lock = threading.Lock()
-    # The following metrics need to be protected by `lock`
-    time_async_update_queue: list = [] # Queuing delays of async update tasks
-    time_async_update_exec: list = [] # Execution times of async update tasks
-    counter_async_update_updated: list = [] # Number of udpated tokens
-    err_update: int = 0 # Number of update errors.
+class CacheServiceMetrics:
+    def __init__(self):
+        # Cache hit statistics.
+        self.hit_tokens: int = 0 # Total number of tokens hit.
+        self.total_tokens: int = 0 # Total number of tokens requested.
+        self.hit_blocks: int = 0 # Total number of blocks hit.
+        self.total_blocks: int = 0 # Total number of blocks requested.
+
+        self.time_query: List[float] = [] # Time spent querying the cache service.
+        self.time_load: List[float] = [] # Time spent loading fetched KV cache into device memory.
+        self.time_reshape: List[float] = [] # Time spent reshaping tensors into the flash attention KV format.
+        self.time_unload: List[float] = [] # Time spent moving computed KV out of device memory.
+        self.time_update: List[float] = [] # Time spent updating computed KV to the cache service.
+
+        self.err_query: int = 0 # Number of query errors.
+        self.err_async_update_task_queue_full: int = 0 # Number of Full exceptions when enqueuing async update tasks.
+
+        self.lock: threading.Lock = threading.Lock()
+        # The following metrics need to be protected by `lock`
+        self.time_async_update_queue: List[float] = [] # Queuing delays of async update tasks.
+        self.time_async_update_exec: List[float] = [] # Execution times of async update tasks.
+        self.counter_async_update_updated: List[int] = [] # Number of tokens updated per async update task.
+        self.err_update: int = 0 # Number of update errors.
+ + def __getstate__(self): + # Create a state dictionary excluding the lock + state = self.__dict__.copy() + del state['lock'] + return state + + def __setstate__(self, state): + # Restore the instance attributes + self.__dict__.update(state) + # Reinitialize the lock + self.lock = threading.Lock() + + def add_time_query(self, value): + self.time_query.append(value) + + def add_time_load(self, value): + self.time_load.append(value) + + def add_time_reshape(self, value): + self.time_reshape.append(value) + + def add_time_unload(self, value): + self.time_unload.append(value) + + def add_time_update(self, value): + self.time_update.append(value) + + def get_tokens_hit_rate(self): + return 0 if self.total_tokens == 0 else self.hit_tokens / float(self.total_tokens) + + def get_blocks_hit_rate(self): + return 0 if self.total_blocks == 0 else self.hit_blocks / float(self.total_blocks) + + def update_async_metrics(self, queue_duration, exec_duration, updated): + with self.lock: + self.time_async_update_queue.append(queue_duration) + self.time_async_update_exec.append(exec_duration) + self.counter_async_update_updated.append(updated) + + def get_async_metrics(self): + with self.lock: + return self.time_async_update_queue, self.time_async_update_exec, self.counter_async_update_updated + + def reset_async_metrics(self): + with self.lock: + self.time_async_update_queue = [] + self.time_async_update_exec = [] + self.counter_async_update_updated = [] class VineyardLLMCache: def __init__( @@ -69,6 +116,7 @@ def __init__( kv_cache_dtype: str = None, torch_dtype: torch.dtype = torch.bfloat16, metrics: CacheServiceMetrics = None, + metrics_enabled: bool = False, enable_async_update: bool = False, min_inflight_tasks: int = 1, max_inflight_tasks: int = 1, @@ -125,7 +173,8 @@ def __init__( self._background_loop.start() self.metrics = metrics - logger.info(f"VineyardLLMCache init {metrics}") + self.metrics_enabled = metrics_enabled + logger.info(f"VineyardLLMCache init {metrics} metrics_enabled {metrics_enabled}") logger.info(self) def _pinned_tensor_creator( @@ -217,6 +266,10 @@ def from_envs( kwargs["min_inflight_tasks"] = min(envs.VINEYARD_CACHE_MIN_INFLIGHT_TASKS, max_inflight_tasks) kwargs["max_inflight_tasks"] = max_inflight_tasks logger.info(f"VineyardLLMCache async update: {kwargs}") + + metrics_enabled = False + if envs.VINEYARD_CACHE_METRICS_ENABLED: + metrics_enabled = True # convert cache capacity to number of tokens cache_capacity = ( @@ -233,6 +286,7 @@ def from_envs( kv_cache_dtype=kv_cache_dtype, torch_dtype=torch_dtype, metrics = metrics, + metrics_enabled = metrics_enabled, **kwargs, ) @@ -301,8 +355,11 @@ def prefetch_seq_kv_caches( if query_token_size <= 0: return seq_id, 0 - - start_time = time.perf_counter() + + self.metrics.total_tokens += query_token_size + self.metrics.total_blocks += ((-query_token_size) // (-block_size)) + if self.metrics_enabled: + start_time = time.perf_counter() matched = 0 try: matched = self.cache.query( @@ -311,10 +368,11 @@ def prefetch_seq_kv_caches( kv_cache_list=self.fetch_tensors[:query_token_size], ) except Exception: - self.metrics.err_query += 1 - duration = time.perf_counter() - start_time - self.metrics.time_query.append(duration) - self.metrics.normalized_time_query.append(duration/(len(query_tokens) + len(query_prefix))) + if self.metrics_enabled: + self.metrics.err_query += 1 + if self.metrics_enabled: + duration = time.perf_counter() - start_time + self.metrics.add_time_query(duration) # If sampling is required, we need to leave one token 
unmatched
         # to trigger the following sampling step in engine worker's workflow.
         if seq_group_metadata is not None and seq_group_metadata.is_sampling_enabled:
@@ -327,7 +385,6 @@
             # shift
             offset = context_len % self.chunk_size
             matched -= offset
-
         if matched <= 0:
             return seq_id, 0
         if get_tensor_model_parallel_rank() == 0:
@@ -344,32 +401,29 @@
             slot_mapping = torch.zeros((matched,), dtype=torch.long, device='cuda')
         tensor_model_parallel_broadcast(slot_mapping, src=0)
         self.metrics.hit_tokens += matched
-        self.metrics.total_tokens += query_token_size
         self.metrics.hit_blocks += (matched // block_size)
-        self.metrics.total_blocks += ((-query_token_size) // (-block_size))
-        self.metrics.counter += 1
-
-        # save to GPU kv cache
-        torch.cuda.synchronize()
-        copy_start = torch.cuda.Event(enable_timing=True)
-        copy_end = torch.cuda.Event(enable_timing=True)
-        copy_start.record()
+        if self.metrics_enabled:
+            # save to GPU kv cache
+            torch.cuda.synchronize()
+            copy_start = torch.cuda.Event(enable_timing=True)
+            copy_end = torch.cuda.Event(enable_timing=True)
+            copy_start.record()
         # Copying the entire buffer to the GPU in a single operation and then
         # slicing it into smaller, non-contiguou chunks on the GPU is more
         # efficient than performing multiple smaller copy operations. This
         # approach reduces the number of transfers between CPU and GPU,
         # leading to faster overall performance.
         buffer = self.cuda_buffer.copy_(self.fetch_buffer)[:, :, :matched]
-        copy_end.record()
-        copy_end.synchronize()
-        duration = copy_start.elapsed_time(copy_end) / 1000.0
-        self.metrics.time_load.append(duration)
-        self.metrics.normalized_time_load.append(0 if matched == 0 else duration/matched)
-
-        torch.cuda.synchronize()
-        reshape_start = torch.cuda.Event(enable_timing=True)
-        reshape_end = torch.cuda.Event(enable_timing=True)
-        reshape_start.record()
+        if self.metrics_enabled:
+            copy_end.record()
+            copy_end.synchronize()
+            duration = copy_start.elapsed_time(copy_end) / 1000.0
+            self.metrics.add_time_load(duration)
+
+            torch.cuda.synchronize()
+            reshape_start = torch.cuda.Event(enable_timing=True)
+            reshape_end = torch.cuda.Event(enable_timing=True)
+            reshape_start.record()
         for j in range(self.layer):
             # use `reshape_and_cache_flash` rather than `copy_` as
             # the target kv cache slots is not contingous.
@@ -383,11 +437,11 @@
                 1.0,
                 1.0
             )
-        reshape_end.record()
-        reshape_end.synchronize()
-        duration = reshape_start.elapsed_time(reshape_end) / 1000.0
-        self.metrics.time_reshape.append(duration)
-        self.metrics.normalized_time_reshape.append(0 if matched == 0 else duration/matched)
+        if self.metrics_enabled:
+            reshape_end.record()
+            reshape_end.synchronize()
+            duration = reshape_start.elapsed_time(reshape_end) / 1000.0
+            self.metrics.add_time_reshape(duration)
 
         # update the seq_group_metadata's and seq's metadata
         self._update_seq_group_metadata(seq_group_metadata, matched)
@@ -453,28 +507,32 @@
             scheduled_time: The timestamp that the task is scheduled.
''' try: - start_time = time.perf_counter() - queue_duration = start_time - scheduled_time + if self.metrics_enabled: + start_time = time.perf_counter() + queue_duration = start_time - scheduled_time update_token_size = len(tokens) kv_cache_list = buffer_tensors_tuple[1][:update_token_size] updated = self.cache.update(prefix, tokens, kv_cache_list) - exec_duration = time.perf_counter() - start_time + if self.metrics_enabled: + exec_duration = time.perf_counter() - start_time if self.enable_async_update: - logger.debug( - f"update kv cache: #prefix={len(prefix)}, #tokens={len(tokens)}, updated={updated}, " - f"queue_duration={queue_duration:.4f}, exec_duration={exec_duration:.4f}" - ) - with self.metrics.lock: - self.metrics.time_async_update_queue.append(queue_duration) - self.metrics.time_async_update_exec.append(exec_duration) - self.metrics.counter_async_update_updated.append(updated) + if self.metrics_enabled: + logger.debug( + f"update kv cache: #prefix={len(prefix)}, #tokens={len(tokens)}, updated={updated}, " + f"queue_duration={queue_duration:.4f}, exec_duration={exec_duration:.4f}" + ) + self.metrics.update_async_metrics(queue_duration, exec_duration, updated) + else: + logger.debug( + f"update kv cache: #prefix={len(prefix)}, #tokens={len(tokens)}, updated={updated}") else: logger.debug( f"update kv cache: #prefix={len(prefix)}, #tokens={len(tokens)}, updated={updated}" ) except Exception: - with self.metrics.lock: - self.metrics.err_update += 1 + if self.metrics_enabled: + with self.metrics.lock: + self.metrics.err_update += 1 finally: if self.enable_async_update: self.tensor_pool.put(buffer_tensors_tuple) @@ -557,20 +615,21 @@ def update_seq_kv_caches( # fetch from GPU kv cache - torch.cuda.synchronize() - start_unload = torch.cuda.Event(enable_timing=True) - end_unload = torch.cuda.Event(enable_timing=True) - start_unload.record() + if self.metrics_enabled: + torch.cuda.synchronize() + start_unload = torch.cuda.Event(enable_timing=True) + end_unload = torch.cuda.Event(enable_timing=True) + start_unload.record() # using a cuda staging buffer to avoid the inefficiency of non-contiguous HBM->DRAM memcpy for j in range(self.layer): self.cuda_buffer[:, j, :update_token_size].copy_( kv_caches[j][:, slot_mapping // block_size, slot_mapping % block_size]) update_buffer.copy_(self.cuda_buffer) - end_unload.record() - end_unload.synchronize() - duration = start_unload.elapsed_time(end_unload) / 1000.0 - self.metrics.time_unload.append(duration) - self.metrics.normalized_time_unload.append(0 if update_token_size == 0 else duration/update_token_size) + if self.metrics_enabled: + end_unload.record() + end_unload.synchronize() + duration = start_unload.elapsed_time(end_unload) / 1000.0 + self.metrics.add_time_unload(duration) start_time = time.perf_counter() @@ -592,14 +651,15 @@ def update_seq_kv_caches( ) except Full: logger.warning(f"update_seq_kv_caches: queue is full, skip this update") - self.metrics.err_async_update_task_queue_full += 1 + if self.metrics_enabled: + self.metrics.err_async_update_task_queue_full += 1 self.tensor_pool.put(buffer_tensors_tuple) else: update_task() - - duration = time.perf_counter() - start_time - self.metrics.time_update.append(duration) - self.metrics.normalized_time_update.append(0 if update_token_size == 0 else duration/update_token_size) + + if self.metrics_enabled: + duration = time.perf_counter() - start_time + self.metrics.add_time_update(duration) # restore the seq_group_metadata's and seq's metadata 
self._update_seq_group_metadata(seq_group_metadata, -matched[seq_id])
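
A minimal usage sketch of the CacheServiceMetrics API added in this patch, assuming the patched vLLM tree (and its vineyard dependency) is importable; the sample values are illustrative only.

    import os
    import pickle

    # Opt in to metrics collection via the new environment variable.
    os.environ["VINEYARD_CACHE_METRICS_ENABLED"] = "1"

    from vllm.worker.vineyard_llm_cache import CacheServiceMetrics

    metrics = CacheServiceMetrics()
    # Hit/total counters are plain instance attributes updated by the cache wrapper.
    metrics.total_tokens += 128
    metrics.hit_tokens += 96
    metrics.total_blocks += 8
    metrics.hit_blocks += 6
    # Timing samples (seconds) go through the add_time_* helpers.
    metrics.add_time_query(0.012)
    # Async update samples go through the lock-protected helper.
    metrics.update_async_metrics(queue_duration=0.001, exec_duration=0.020, updated=96)

    print(metrics.get_tokens_hit_rate())  # 0.75 (a fraction, not a percentage)
    print(metrics.get_blocks_hit_rate())  # 0.75
    queue_times, exec_times, updated_counts = metrics.get_async_metrics()
    metrics.reset_async_metrics()

    # __getstate__/__setstate__ drop and recreate the lock, so the object
    # survives pickling across processes.
    clone = pickle.loads(pickle.dumps(metrics))
    assert clone.hit_tokens == 96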