[V1][BugFix] Free encoder cache for aborted requests (vllm-project#12545)

Signed-off-by: Woosuk Kwon <[email protected]>
WoosukKwon authored Jan 29, 2025
1 parent 73aa6cf commit e0cc5f2
Showing 2 changed files with 16 additions and 7 deletions.
9 changes: 8 additions & 1 deletion vllm/v1/core/encoder_cache_manager.py
@@ -38,7 +38,8 @@ def allocate(self, request: Request, input_id: int) -> None:
     def get_cached_input_ids(self, request: Request) -> Set[int]:
         return self.cached.get(request.request_id, set())
 
-    def free(self, request: Request, input_id: int) -> None:
+    def free_encoder_input(self, request: Request, input_id: int) -> None:
+        """Free a single encoder input id for the request."""
         req_id = request.request_id
         if req_id not in self.cached:
             return
@@ -49,6 +50,12 @@ def free(self, request: Request, input_id: int) -> None:
         self.num_free_slots += request.get_num_encoder_tokens(input_id)
         self.freed.append((req_id, input_id))
 
+    def free(self, request: Request) -> None:
+        """Free all cached input ids for the request."""
+        input_ids = self.get_cached_input_ids(request)
+        for input_id in input_ids:
+            self.free_encoder_input(request, input_id)
+
     def get_freed_ids(self) -> List[Tuple[str, int]]:
         freed = self.freed
         self.freed = []
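The change above splits the old free() into two entry points: free_encoder_input() releases a single encoder input id, while the new free() drains everything a request still holds. The toy harness below is a simplified, self-contained sketch of that pattern; ToyRequest and ToyEncoderCacheManager are illustrative stand-ins, not the real vllm classes, and details such as capacity checks when allocating are omitted.

# Minimal sketch of the new split: free_encoder_input() releases one input id,
# free() releases everything a request still holds. The classes here are
# simplified stand-ins for illustration, not the real vllm implementations.
from typing import Dict, List, Set, Tuple


class ToyRequest:
    def __init__(self, request_id: str, encoder_token_counts: Dict[int, int]):
        self.request_id = request_id
        self._encoder_token_counts = encoder_token_counts

    def get_num_encoder_tokens(self, input_id: int) -> int:
        return self._encoder_token_counts[input_id]


class ToyEncoderCacheManager:
    def __init__(self, cache_size: int):
        self.num_free_slots = cache_size
        self.cached: Dict[str, Set[int]] = {}
        self.freed: List[Tuple[str, int]] = []

    def allocate(self, request: ToyRequest, input_id: int) -> None:
        # No capacity check in this toy version.
        self.cached.setdefault(request.request_id, set()).add(input_id)
        self.num_free_slots -= request.get_num_encoder_tokens(input_id)

    def get_cached_input_ids(self, request: ToyRequest) -> Set[int]:
        return self.cached.get(request.request_id, set())

    def free_encoder_input(self, request: ToyRequest, input_id: int) -> None:
        """Free a single encoder input id for the request."""
        req_id = request.request_id
        if req_id not in self.cached:
            return
        self.cached[req_id].discard(input_id)
        self.num_free_slots += request.get_num_encoder_tokens(input_id)
        self.freed.append((req_id, input_id))

    def free(self, request: ToyRequest) -> None:
        """Free all cached input ids for the request."""
        for input_id in list(self.get_cached_input_ids(request)):
            self.free_encoder_input(request, input_id)


if __name__ == "__main__":
    manager = ToyEncoderCacheManager(cache_size=1024)
    req = ToyRequest("req-0", {0: 256, 1: 128})
    manager.allocate(req, 0)
    manager.allocate(req, 1)
    manager.free_encoder_input(req, 0)   # one image's encoder output is done
    manager.free(req)                    # request aborted: drop the rest
    print(manager.num_free_slots)        # back to 1024

The sketch copies the cached ids with list() before iterating so the set is not mutated mid-loop; the sizes used in __main__ are made up for illustration.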
14 changes: 8 additions & 6 deletions vllm/v1/core/scheduler.py
@@ -202,7 +202,7 @@ def schedule(self) -> "SchedulerOutput":
                 # which have output tokens.
                 num_new_tokens = request.num_tokens - num_computed_tokens
                 if num_new_tokens == 0:
-                    # The happens when prompt length is divisible by the block
+                    # This happens when prompt length is divisible by the block
                     # size and all blocks are cached. Now we force to recompute
                     # the last block. Note that we have to re-compute an entire
                     # block because allocate_slots() assumes num_computed_tokens
@@ -269,6 +269,7 @@ def schedule(self) -> "SchedulerOutput":
 
         # Get the longest common prefix among all requests in the running queue.
         # This can be potentially used for cascade attention.
+        num_common_prefix_blocks = 0
         if self.running:
             any_request = self.running[0]
             num_common_prefix_blocks = (
@@ -433,7 +434,8 @@ def update_from_output(
                 if start_pos + num_tokens <= request.num_computed_tokens:
                     # The encoder output is already processed and stored
                     # in the decoder's KV cache.
-                    self.encoder_cache_manager.free(request, input_id)
+                    self.encoder_cache_manager.free_encoder_input(
+                        request, input_id)
 
             if request.num_computed_tokens == request.num_tokens:
                 req_index = model_runner_output.req_id_to_index[req_id]
@@ -445,8 +447,10 @@
                 # TODO: Update the KV cache manager for prefix caching.
 
                 # Check for stop and update request state.
-                # This must be called before me make the EngineCoreOutput.
+                # This must be called before we make the EngineCoreOutput.
                 stopped = self._check_stop(request)
+                if stopped:
+                    self._free_request(request)
 
                 # Add EngineCoreOutput for this Request.
                 output = EngineCoreOutput(
@@ -472,21 +476,18 @@ def _check_stop(self, request: Request) -> bool:
         if (request.num_tokens >= self.max_model_len
                 or request.num_output_tokens >= request.max_tokens):
             request.status = RequestStatus.FINISHED_LENGTH_CAPPED
-            self._free_request(request)
             return True
 
         sampling_params = request.sampling_params
         last_token_id = request.output_token_ids[-1]
         if (not sampling_params.ignore_eos
                 and last_token_id == request.eos_token_id):
             request.status = RequestStatus.FINISHED_STOPPED
-            self._free_request(request)
             return True
 
         if last_token_id in (sampling_params.stop_token_ids or ()):
             request.status = RequestStatus.FINISHED_STOPPED
             request.stop_reason = last_token_id
-            self._free_request(request)
             return True
         return False

@@ -525,6 +526,7 @@ def finish_requests(
     def _free_request(self, request: Request) -> None:
         assert request.is_finished()
         self.kv_cache_manager.free(request)
+        self.encoder_cache_manager.free(request)
         self.running_reqs_data.pop(request.request_id, None)
         del self.requests[request.request_id]
         self.finished_req_ids.add(request.request_id)
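On the scheduler side, the net effect of the diff above is that _check_stop() only classifies a request, while the caller frees it, and _free_request() now releases the encoder cache alongside the KV cache. The sketch below mirrors that control flow with stand-in classes; ToyScheduler, ToyCache, and ToyRequest are hypothetical simplifications for illustration, not vllm's actual implementations.

# Sketch of the control-flow change: _check_stop() only classifies, the
# caller (or an abort path) frees, and _free_request() releases both caches.
# All classes here are illustrative stand-ins, not the real vllm code.
from dataclasses import dataclass
from typing import Set


@dataclass
class ToyRequest:
    request_id: str
    num_tokens: int = 0
    finished: bool = False


class ToyCache:
    """Stand-in for both the KV cache manager and the encoder cache manager."""

    def __init__(self) -> None:
        self.freed: Set[str] = set()

    def free(self, request: ToyRequest) -> None:
        self.freed.add(request.request_id)


class ToyScheduler:
    def __init__(self, max_model_len: int = 8) -> None:
        self.max_model_len = max_model_len
        self.kv_cache_manager = ToyCache()
        self.encoder_cache_manager = ToyCache()

    def _check_stop(self, request: ToyRequest) -> bool:
        # Only classify the request; freeing is left to the caller so that
        # normal stops and aborts share one cleanup path.
        if request.num_tokens >= self.max_model_len:
            request.finished = True
            return True
        return False

    def _free_request(self, request: ToyRequest) -> None:
        assert request.finished
        self.kv_cache_manager.free(request)
        self.encoder_cache_manager.free(request)

    def update_from_output(self, request: ToyRequest) -> None:
        # Mirrors the scheduler loop: check for stop, then free if stopped.
        if self._check_stop(request):
            self._free_request(request)

    def abort(self, request: ToyRequest) -> None:
        # An aborted request never hits a stop condition, but it still goes
        # through _free_request(), so its encoder cache is released as well.
        request.finished = True
        self._free_request(request)


if __name__ == "__main__":
    sched = ToyScheduler()
    aborted = ToyRequest("req-abort", num_tokens=3)
    sched.abort(aborted)
    print("req-abort" in sched.encoder_cache_manager.freed)  # True

Routing both normal stops and aborts through a single _free_request() is what guarantees the encoder cache of an aborted multimodal request is returned, which is the bug this commit fixes.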
