Skip to content

Commit

Permalink
Add test and fix non-streaming multi-sequence
Browse files Browse the repository at this point in the history
  • Loading branch information
andylolu2 committed Jan 12, 2025
1 parent f9a2eb2 commit debff7f
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 33 deletions.
7 changes: 6 additions & 1 deletion tests/samplers/test_seeded_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_random_sample_with_seed(

sampling_params = SamplingParams(
# Parameters to ensure sufficient randomness
temperature=2.0,
temperature=3.0,
top_p=min(random.random() + 0.3, 1),
top_k=random.randint(5, 20),
n=random.randint(1, 10),
Expand Down Expand Up @@ -75,3 +75,8 @@ def test_random_sample_with_seed(
# verify requests with the same seed match
assert outputs[1] == outputs[4]
assert outputs[2] == outputs[5]

# verify generations within the same parallel sampling group differ
for output in outputs:
for sub_output_a, sub_output_b in combinations(output, 2):
assert sub_output_a != sub_output_b
55 changes: 23 additions & 32 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,17 +817,6 @@ def get_max_num_running_seqs(self) -> int:
lifetime of the request."""
if self.is_single_seq:
return 0 if self.first_seq.is_finished() else 1
if self.sampling_params:
n = self.sampling_params.n
assert isinstance(n, int)
if n > self.num_seqs():
# At prompt stage, the sequence group is not yet filled up
# and only have one sequence running. However, in the
# generation stage, we will have `n` sequences
# running.
return n
# At sampling stages, return the number of actual sequences
# that are not finished yet.
return self.num_seqs() - self.num_finished_seqs()

def get_seqs(
Expand Down Expand Up @@ -1466,25 +1455,27 @@ def maybe_assemble_group(
return None

# in the non-streaming mode, we will return the assembled sequence
# once after all sequences finish, and then return None for the
# when the last sequences finishes, and then return None for the
# rest of the time

if len(self.to_be_finished) > 0:
return None

assert self.assembled_seq_group is not None
params = self.assembled_seq_group.sampling_params
assert isinstance(params, SamplingParams)
if not self.output_produced:
self.output_produced = True
if params._real_n is not None:
# Get the top-n sequences.
n = params._real_n or params.n
seqs = self.assembled_seq_group.seqs
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]
self.assembled_seq_group.seqs = top_n_seqs
return self.assembled_seq_group
if self.output_produced:
return None
if (
len(self.to_be_finished) == 1
and seq_group.request_id in self.to_be_finished
and seq_group.is_finished()
):
assert self.assembled_seq_group is not None
params = self.assembled_seq_group.sampling_params
assert isinstance(params, SamplingParams)
if not self.output_produced:
self.output_produced = True
if params._real_n is not None:
# Get the top-n sequences.
n = params._real_n or params.n
seqs = self.assembled_seq_group.seqs
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]
self.assembled_seq_group.seqs = top_n_seqs
return self.assembled_seq_group
if self.output_produced:
return None
return None

0 comments on commit debff7f

Please sign in to comment.