From debff7f01a495c0d0e01fa27c32fb69d0e0acb4a Mon Sep 17 00:00:00 2001 From: Andy Lo Date: Sun, 12 Jan 2025 21:54:28 +0000 Subject: [PATCH] Add test and fix non-streaming multi-sequence --- tests/samplers/test_seeded_generate.py | 7 +++- vllm/sequence.py | 55 +++++++++++--------------- 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index 88067f19c8f07..bf1ee6c397838 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -31,7 +31,7 @@ def test_random_sample_with_seed( sampling_params = SamplingParams( # Parameters to ensure sufficient randomness - temperature=2.0, + temperature=3.0, top_p=min(random.random() + 0.3, 1), top_k=random.randint(5, 20), n=random.randint(1, 10), @@ -75,3 +75,8 @@ def test_random_sample_with_seed( # verify requests with the same seed match assert outputs[1] == outputs[4] assert outputs[2] == outputs[5] + + # verify generations within the same parallel sampling group differ + for output in outputs: + for sub_output_a, sub_output_b in combinations(output, 2): + assert sub_output_a != sub_output_b diff --git a/vllm/sequence.py b/vllm/sequence.py index 06f13bbde959c..a854899436ebe 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -817,17 +817,6 @@ def get_max_num_running_seqs(self) -> int: lifetime of the request.""" if self.is_single_seq: return 0 if self.first_seq.is_finished() else 1 - if self.sampling_params: - n = self.sampling_params.n - assert isinstance(n, int) - if n > self.num_seqs(): - # At prompt stage, the sequence group is not yet filled up - # and only have one sequence running. However, in the - # generation stage, we will have `n` sequences - # running. - return n - # At sampling stages, return the number of actual sequences - # that are not finished yet. 
return self.num_seqs() - self.num_finished_seqs() def get_seqs( @@ -1466,25 +1455,27 @@ def maybe_assemble_group( return None # in the non-streaming mode, we will return the assembled sequence - # once after all sequences finish, and then return None for the + # when the last sequence finishes, and then return None for the # rest of the time - - if len(self.to_be_finished) > 0: - return None - - assert self.assembled_seq_group is not None - params = self.assembled_seq_group.sampling_params - assert isinstance(params, SamplingParams) - if not self.output_produced: - self.output_produced = True - if params._real_n is not None: - # Get the top-n sequences. - n = params._real_n or params.n - seqs = self.assembled_seq_group.seqs - sorting_key = lambda seq: seq.get_cumulative_logprob() - sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) - top_n_seqs = sorted_seqs[:n] - self.assembled_seq_group.seqs = top_n_seqs - return self.assembled_seq_group - if self.output_produced: - return None + if ( + len(self.to_be_finished) == 1 + and seq_group.request_id in self.to_be_finished + and seq_group.is_finished() + ): + assert self.assembled_seq_group is not None + params = self.assembled_seq_group.sampling_params + assert isinstance(params, SamplingParams) + if not self.output_produced: + self.output_produced = True + if params._real_n is not None: + # Get the top-n sequences. + n = params._real_n or params.n + seqs = self.assembled_seq_group.seqs + sorting_key = lambda seq: seq.get_cumulative_logprob() + sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) + top_n_seqs = sorted_seqs[:n] + self.assembled_seq_group.seqs = top_n_seqs + return self.assembled_seq_group + if self.output_produced: + return None + return None