Skip to content

Commit

Permalink
Fix bugs of logprobs_nums (#1548)
Browse files Browse the repository at this point in the history
  • Loading branch information
hnyls2002 authored Oct 1, 2024
1 parent 99ec439 commit b88ea90
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,8 @@ def filter_batch(self, unfinished_indices: List[int]):
self.top_logprobs_nums = [
self.top_logprobs_nums[i] for i in unfinished_indices
]
else:
self.top_logprobs_nums = None
self.has_stream = any(req.stream for req in self.reqs)

self.sampling_info.filter_batch(unfinished_indices, new_indices)
Expand All @@ -758,20 +760,20 @@ def merge_batch(self, other: "ScheduleBatch"):
# needs to be called with pre-merged Batch.reqs.
self.sampling_info.merge_batch(other.sampling_info)

self.reqs.extend(other.reqs)
self.req_pool_indices = torch.concat(
[self.req_pool_indices, other.req_pool_indices]
)
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
self.out_cache_loc = None
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob and other.return_logprob:
self.top_logprobs_nums.extend(other.top_logprobs_nums)
elif self.return_logprob:
self.top_logprobs_nums.extend([0] * len(other.reqs))
elif other.return_logprob:
self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
self.has_stream = any(req.stream for req in self.reqs)
self.reqs.extend(other.reqs)
self.return_logprob = self.return_logprob or other.return_logprob

def get_model_worker_batch(self):
if self.forward_mode.is_decode():
Expand Down

0 comments on commit b88ea90

Please sign in to comment.