Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Multi-sequence broken #11898

Merged
merged 3 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tests/samplers/test_seeded_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_random_sample_with_seed(

sampling_params = SamplingParams(
# Parameters to ensure sufficient randomness
temperature=2.0,
temperature=3.0,
top_p=min(random.random() + 0.3, 1),
top_k=random.randint(5, 20),
n=random.randint(1, 10),
Expand Down Expand Up @@ -75,3 +75,8 @@ def test_random_sample_with_seed(
# verify requests with the same seed match
assert outputs[1] == outputs[4]
assert outputs[2] == outputs[5]

# verify generations within the same parallel sampling group differ
for output in outputs:
for sub_output_a, sub_output_b in combinations(output, 2):
assert sub_output_a != sub_output_b
Comment on lines +78 to +82
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would this test be flaky? e.g. when we generate with n=10, if two sequences happen to generate the same answer ...

2 changes: 1 addition & 1 deletion vllm/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ def from_seq_group(
if seq_group.request_id in seq_id_to_seq_group:
group: SequenceGroupBase = seq_id_to_seq_group[
seq_group.request_id]
assembled_seq_group = group.maybe_assemble_group(seq_group)
if finished:
group.finish_seq(seq_group)
assembled_seq_group = group.maybe_assemble_group(seq_group)
if assembled_seq_group is None:
return None
return cls.from_seq_group(assembled_seq_group, use_cache,
Expand Down
89 changes: 52 additions & 37 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,9 @@ def set_finished_time(self, time: Optional[float]) -> None:
def get_max_num_running_seqs(self) -> int:
    """The maximum number of sequences running in parallel in the remaining
    lifetime of the request."""
    if self.is_single_seq:
        # Fast path for n=1: the lone sequence either still runs (1)
        # or is done (0).
        return 0 if self.first_seq.is_finished() else 1
    # Parallel sampling: every not-yet-finished sequence may still run.
    return self.num_seqs() - self.num_finished_seqs()

def get_seqs(
self,
Expand All @@ -824,7 +826,10 @@ def get_seqs(
if status is None:
return self.seqs

return self.seqs if self.first_seq.status == status else []
if self.is_single_seq:
return self.seqs if self.first_seq.status == status else []

return [seq for seq in self.seqs if seq.status == status]

def is_encoder_decoder(self) -> bool:
return self.encoder_seq is not None
Expand All @@ -833,19 +838,22 @@ def get_encoder_seq(self) -> Optional[Sequence]:
return self.encoder_seq

def get_finished_seqs(self) -> List[Sequence]:
    """Return the sequences in this group that have finished."""
    if self.is_single_seq:
        # n=1 fast path: the whole (one-element) list or nothing.
        return self.seqs if self.first_seq.is_finished() else []
    return [seq for seq in self.seqs if seq.is_finished()]

def update_num_computed_tokens(self, num_new_computed_tokens: int):
    """Update number of tokens computed so far.

    Only still-running sequences accumulate the new tokens; finished
    sequences must not advance their counters.
    """
    for seq in self.seqs:
        if not seq.is_finished():
            seq.data.update_num_computed_tokens(num_new_computed_tokens)

def get_num_uncomputed_tokens(self) -> int:
    """Total number of not-yet-computed tokens across all unfinished
    sequences in the group."""
    num_uncomputed_tokens = 0
    for seq in self.seqs:
        # Finished sequences contribute no further work.
        if not seq.is_finished():
            num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
    return num_uncomputed_tokens

def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
Expand All @@ -860,10 +868,14 @@ def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
return len(self.get_seqs(status))

def num_finished_seqs(self) -> int:
    """Number of sequences in the group that have finished."""
    if self.is_single_seq:
        # n=1 fast path: avoid building the intermediate list.
        return 1 if self.seqs[0].is_finished() else 0
    return len(self.get_finished_seqs())

def is_finished(self) -> bool:
    """The group is finished only once every sequence in it has finished."""
    if self.is_single_seq:
        return self.first_seq.is_finished()
    return all(seq.is_finished() for seq in self.seqs)

def is_prefill(self) -> bool:
    # Prefill status of the whole group tracks the first sequence;
    # all sequences in a group share the same prompt/prefill phase.
    return self.first_seq.is_prefill()
Expand Down Expand Up @@ -1391,13 +1403,15 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
@staticmethod
def add_request(request_id: str, engine, params, **kwargs):
original_params = params
params = original_params.clone()
params.n = 1
group = ParallelSampleSequenceGroup(request_id)
seqs = []
for i in range(original_params.n):
request_id_i = f"{request_id}_parallel_sample_{i}"
group.seq_id_to_index[request_id_i] = i
params = copy.deepcopy(original_params)
params.n = 1
if params.seed is not None:
params.seed += i
Comment on lines +1411 to +1414
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this part makes sense to me.

seq_group = engine._add_processed_request(
request_id_i,
params=params,
def maybe_assemble_group(
        self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:
    """Return the assembled parent group when output should be emitted
    for ``seq_group``, otherwise None.

    Streaming mode returns the assembled group for the first remaining
    child; non-streaming mode returns it exactly once, when the last
    child finishes.
    """
    # in the streaming mode, we will return the assembled sequence
    # for the first remaining sequence, and then return None for the
    # rest of sequences
    if self.streaming:
        first_remaining_id = next(iter(self.to_be_finished))
        if seq_group.request_id == first_remaining_id:
            return self.assembled_seq_group
        return None

    # in the non-streaming mode, we will return the assembled sequence
    # when the last sequence finishes, and then return None for the
    # rest of the time
    if (len(self.to_be_finished) == 1
            and seq_group.request_id in self.to_be_finished
            and seq_group.is_finished()):
        assert self.assembled_seq_group is not None
        params = self.assembled_seq_group.sampling_params
        assert isinstance(params, SamplingParams)
        if not self.output_produced:
            self.output_produced = True
            if params._real_n is not None:
                # Get the top-n sequences (best-of semantics): keep the
                # n highest-cumulative-logprob sequences.
                n = params._real_n or params.n
                seqs = self.assembled_seq_group.seqs
                sorting_key = lambda seq: seq.get_cumulative_logprob()
                sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
                top_n_seqs = sorted_seqs[:n]
                self.assembled_seq_group.seqs = top_n_seqs
            return self.assembled_seq_group
        if self.output_produced:
            return None
    return None
Loading