Skip to content

Commit

Permalink
Fix the prefix indices (#1037)
Browse files Browse the repository at this point in the history
  • Loading branch information
hnyls2002 authored Aug 12, 2024
1 parent d84c5e7 commit 7de6034
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 15 deletions.
15 changes: 10 additions & 5 deletions python/sglang/srt/managers/policy_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,14 @@ def __init__(self, policy: str, tree_cache: BasePrefixCache):

def calc_priority(self, waiting_queue: List[Req]):
# Compute matched prefix length
for r in waiting_queue:
# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=r.adjust_max_prefix_ids()
)
prefix_computed = False
if self.policy in ["lpm", "dfs-weight"]:
for r in waiting_queue:
# NOTE: the prefix_indices must always be aligned with last_node
r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
rid=r.rid, key=r.adjust_max_prefix_ids()
)
prefix_computed = True

if self.policy == "lpm":
# Longest Prefix Match
Expand Down Expand Up @@ -80,6 +83,8 @@ def calc_priority(self, waiting_queue: List[Req]):
else:
raise ValueError(f"Unknown schedule_policy: {self.policy}")

return prefix_computed

def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
for child in cur_node.children.values():
self.calc_weight(child, node_to_weight)
Expand Down
15 changes: 9 additions & 6 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,18 @@
import logging
import warnings
from dataclasses import dataclass
from typing import List, Union
from typing import List, Optional, Union

import numpy as np
import torch
from flashinfer.sampling import top_k_top_p_sampling_from_probs

import sglang.srt.sampling.penaltylib as penaltylib
from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixCache

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

Expand Down Expand Up @@ -164,8 +163,12 @@ def __init__(self, rid, origin_input_text, origin_input_ids):
def finished(self) -> bool:
return self.finished_reason is not None

def init_next_round_input(self):
def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
self.fill_ids = self.origin_input_ids + self.output_ids
if tree_cache is not None:
self.prefix_indices, self.last_node = tree_cache.match_prefix(
rid=self.rid, key=self.adjust_max_prefix_ids()
)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

def adjust_max_prefix_ids(self):
Expand Down Expand Up @@ -312,7 +315,7 @@ class ScheduleBatch:
reqs: List[Req]
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: BaseTokenToKVPool
tree_cache: RadixCache
tree_cache: BasePrefixCache

# Batched arguments to model runner
input_ids: torch.Tensor = None
Expand Down Expand Up @@ -534,7 +537,7 @@ def retract_decode(self):
residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)

req.prefix_indices = None
req.prefix_indices = []
req.last_node = None
req.extend_input_len = 0

Expand Down
8 changes: 5 additions & 3 deletions python/sglang/srt/managers/tp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
return None

# Get priority queue
self.scheduler.calc_priority(self.waiting_queue)
prefix_computed = self.scheduler.calc_priority(self.waiting_queue)

adder = PrefillAdder(
self.tree_cache,
Expand All @@ -383,13 +383,15 @@ def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:

has_inflight = self.current_inflight_req is not None
if self.current_inflight_req is not None:
self.current_inflight_req.init_next_round_input()
self.current_inflight_req.init_next_round_input(
None if prefix_computed else self.tree_cache
)
self.current_inflight_req = adder.add_inflight_req(
self.current_inflight_req
)

for req in self.waiting_queue:
req.init_next_round_input()
req.init_next_round_input(None if prefix_computed else self.tree_cache)
res = adder.add_one_req(req)
if (
not res
Expand Down
2 changes: 1 addition & 1 deletion python/sglang/srt/mem_cache/radix_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import heapq
import time
from collections import defaultdict
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, List, Optional

import torch

Expand Down

0 comments on commit 7de6034

Please sign in to comment.