diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 8ff204abd72..d5e400fbdda 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -423,10 +423,14 @@ class ScheduleBatch:
     # Stream
     has_stream: bool = False
 
+    # Has regex
+    has_regex: bool = False
+
     @classmethod
     def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
         return_logprob = any(req.return_logprob for req in reqs)
         has_stream = any(req.stream for req in reqs)
+        has_regex = any(req.regex_fsm for req in reqs)
 
         return cls(
             reqs=reqs,
@@ -435,6 +439,7 @@ def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
             tree_cache=tree_cache,
             return_logprob=return_logprob,
             has_stream=has_stream,
+            has_regex=has_regex,
         )
 
     def batch_size(self):
@@ -750,7 +755,9 @@ def filter_batch(self, unfinished_indices: List[int]):
             ]
         else:
             self.top_logprobs_nums = None
+
         self.has_stream = any(req.stream for req in self.reqs)
+        self.has_regex = any(req.regex_fsm for req in self.reqs)
 
         self.sampling_info.filter_batch(unfinished_indices, new_indices)
 
@@ -771,9 +778,11 @@ def merge_batch(self, other: "ScheduleBatch"):
             self.top_logprobs_nums.extend([0] * len(other.reqs))
         elif other.return_logprob:
             self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
-        self.has_stream = any(req.stream for req in self.reqs)
         self.reqs.extend(other.reqs)
 
+        self.return_logprob = self.return_logprob or other.return_logprob
+        self.has_stream = self.has_stream or other.has_stream
+        self.has_regex = self.has_regex or other.has_regex
 
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode():
@@ -787,7 +796,11 @@ def get_model_worker_batch(self):
         image_inputs = [r.image_inputs for r in self.reqs]
         lora_paths = [req.lora_path for req in self.reqs]
 
-        self.sampling_info.regex_fsm_states = [req.regex_fsm_state for req in self.reqs]
+        if self.has_regex:
+            self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
+            self.sampling_info.regex_fsm_states = [
+                req.regex_fsm_state for req in self.reqs
+            ]
 
         return ModelWorkerBatch(
             forward_mode=self.forward_mode,
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 606f11d9816..fb0dfc53c71 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -84,10 +84,6 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
 
-        # This is only for regex_fsm. We notice a regression if we maintain the list of regex_fsm
-        # in SamplingBatchInfo, so we keep it here.
-        ret.schedule_batch = batch
-
         return ret
 
     def __len__(self):
@@ -113,7 +109,7 @@ def update_penalties(self):
             self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self):
-        has_regex = any(req.regex_fsm is not None for req in self.schedule_batch.reqs)
+        has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
 
         # Reset the vocab mask
         self.vocab_mask = None
@@ -122,11 +118,11 @@ def update_regex_vocab_mask(self):
             self.vocab_mask = torch.zeros(
                 len(self.temperatures), self.vocab_size, dtype=torch.bool, device="cuda"
             )
-            for i, req in enumerate(self.schedule_batch.reqs):
-                if req.regex_fsm is not None:
+            for i, regex_fsm in enumerate(self.regex_fsms):
+                if regex_fsm is not None:
                     self.vocab_mask[i].fill_(1)
                     self.vocab_mask[i][
-                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+                        regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
                     ] = 0
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
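
Note: the vocab_mask built by update_regex_vocab_mask marks every token id the regex FSM currently disallows with True (the mask is filled with 1 and then the FSM-allowed tokens are cleared to 0). Below is a minimal sketch of how such a boolean mask is typically consumed before sampling; it is an illustration only, not the exact SGLang sampler code, and apply_vocab_mask is a hypothetical helper name.

import torch

def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> torch.Tensor:
    # logits: [batch_size, vocab_size]; vocab_mask: same shape, dtype=bool,
    # True for tokens the regex FSM disallows at the current state.
    # Setting those logits to -inf removes them from the next-token distribution.
    if vocab_mask is not None:
        logits = logits.masked_fill(vocab_mask, float("-inf"))
    return logits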