From 241ad7b301facac0728e2b3312d71fe47acc8c9e Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Fri, 10 Jan 2025 20:45:33 +0800
Subject: [PATCH] [ci] Fix sampler tests (#11922)

Signed-off-by: youkaichao
---
 .buildkite/test-pipeline.yaml |  1 +
 tests/conftest.py             | 11 +++++++++--
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index e288f8f30159a..7d13269540864 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -214,6 +214,7 @@ steps:
   - vllm/model_executor/layers
   - vllm/sampling_metadata.py
   - tests/samplers
+  - tests/conftest.py
   commands:
   - pytest -v -s samplers
   - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
diff --git a/tests/conftest.py b/tests/conftest.py
index 917151ddcb8d4..95af4ac1eb17b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -28,12 +28,13 @@
                               init_distributed_environment,
                               initialize_model_parallel)
 from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
-                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
+                         TokensPrompt, to_enc_dec_tuple_list,
+                         zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import BeamSearchParams
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        identity)
+                        identity, is_list_of)

 logger = init_logger(__name__)

@@ -886,6 +887,12 @@ def generate_beam_search(
         beam_width: int,
         max_tokens: int,
     ) -> List[Tuple[List[List[int]], List[str]]]:
+        if is_list_of(prompts, str, check="all"):
+            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
+        else:
+            prompts = [
+                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
+            ]
         outputs = self.model.beam_search(
             prompts,
             BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
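
Note for reviewers: the functional change is the prompt-normalization block added to
generate_beam_search, which wraps raw test prompts into explicit vLLM prompt types
before calling LLM.beam_search(). A minimal standalone sketch of the same dispatch is
below; the helper name normalize_prompts is hypothetical and not part of the diff, but
the imports and the is_list_of check mirror the ones the patch adds.

    from typing import List, Union

    from vllm.inputs import TextPrompt, TokensPrompt
    from vllm.utils import is_list_of


    def normalize_prompts(
        prompts: Union[List[str], List[List[int]]],
    ) -> List[Union[TextPrompt, TokensPrompt]]:
        # Plain strings become TextPrompt objects.
        if is_list_of(prompts, str, check="all"):
            return [TextPrompt(prompt=prompt) for prompt in prompts]
        # Pre-tokenized prompts (lists of token ids) become TokensPrompt objects,
        # so beam_search() always receives explicit prompt types instead of
        # having to guess whether a list of ints is text or token ids.
        return [TokensPrompt(prompt_token_ids=tokens) for tokens in prompts]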