From 7b17c0406dd6c3cca494dc15cccd563bdd2ea3a6 Mon Sep 17 00:00:00 2001
From: Lu Fang
Date: Wed, 12 Feb 2025 12:35:16 -0800
Subject: [PATCH] Support logit_bias in v1 Sampler

Signed-off-by: Lu Fang
---
 tests/v1/sample/test_sampler.py   | 70 +++++++++++++++++++++++++------
 vllm/v1/sample/metadata.py        |  2 +
 vllm/v1/sample/sampler.py         | 14 +++++++
 vllm/v1/worker/gpu_input_batch.py | 27 +++++++++++-
 4 files changed, 100 insertions(+), 13 deletions(-)

diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py
index f7eedcb9c58d6..ab4b33968f59d 100644
--- a/tests/v1/sample/test_sampler.py
+++ b/tests/v1/sample/test_sampler.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 import numpy as np
 import pytest
@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor(
     )
 
 
+def _create_logit_bias(
+    batch_size: int,
+    vocab_size: int,
+    bias_value: float,
+) -> List[Optional[Dict[int, float]]]:
+    res: List[Optional[Dict[int, float]]] = []
+    for i in range(batch_size):
+        logit_bias = {min(i, vocab_size - 1): bias_value}
+        res.append(logit_bias)
+    return res
+
+
 def _create_default_sampling_metadata(
     num_output_tokens: int,
     batch_size: int,
@@ -80,6 +92,7 @@ def _create_default_sampling_metadata(
         no_penalties=True,
         min_tokens=[],
         stop_token_ids=[],
+        logit_bias=[None] * batch_size,
     )
     return fake_sampling_metadata
 
@@ -89,14 +102,14 @@ def _generate_min_token_penalties_and_stop_tokens(
     batch_indices_for_min_token_penalty: List[int]
 ) -> Tuple[List[int], List[Set[int]]]:
     """
-    Generates and returns a list of minimum token penalties (`min_tokens`)
-    and a corresponding list of stop token IDs (`stop_token_ids`) for each
+    Generates and returns a list of minimum token penalties (`min_tokens`) 
+    and a corresponding list of stop token IDs (`stop_token_ids`) for each 
     batch.
 
-    If a batch index is included in `batch_indices_for_min_token_penalty`,
-    a higher `min_tokens` value is assigned (within a randomized range),
-    and a random set of stop token IDs is created. Otherwise, a lower
-    `min_tokens` value is assigned, and the stop token IDs set is empty.
+    If a batch index is included in `batch_indices_for_min_token_penalty`, 
+    a higher `min_tokens` value is assigned (within a randomized range), 
+    and a random set of stop token IDs is created. Otherwise, a lower 
+    `min_tokens` value is assigned, and the stop token IDs set is empty. 
     """
     stop_token_ids: List[Set[int]] = []
     min_tokens: List[int] = []
@@ -120,7 +133,7 @@ def _create_weighted_output_token_list(
         batch_size: int,
         vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
     """
-    Creates an output token list where each token occurs a distinct
+    Creates an output token list where each token occurs a distinct 
     number of times.
 
     For each batch, a random subset of token IDs is selected from the
@@ -129,8 +142,8 @@ def _create_weighted_output_token_list(
 
     Returns:
         Tuple[List[List[int]], List[List[int]]]:
-        - The first element is the output token list, where each sublist
-          corresponds to a batch and contains tokens with weighted
+        - The first element is the output token list, where each sublist 
+          corresponds to a batch and contains tokens with weighted 
           frequencies.
         - The second element is a list of distinct token IDs for each
           batch, ordered by their frequency in the corresponding output
@@ -155,7 +168,7 @@ def _create_weighted_output_token_list(
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
 def test_sampler_min_tokens_penalty(device: str, batch_size: int):
     """
-    Tests that if the number of output tokens is less than
+    Tests that if the number of output tokens is less than 
     SamplingParams.min_tokens then we will set the logits for the stop
     token ids to -inf.
     """
@@ -283,7 +296,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
 def test_sampler_repetition_penalty(device: str, batch_size: int,
                                     repetition_penalty: float):
     """
-    Test to verify that when the repetition penalty is enabled, tokens
+    Test to verify that when the repetition penalty is enabled, tokens 
     are penalized based on their presence in the prompt or the existing
     output.
     """
@@ -321,3 +334,36 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
                 penalized_token_id not in output_tokens)
         assert (non_penalized_token_id in prompt_tokens or \
                 non_penalized_token_id in output_tokens)
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
+def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
+    """
+    Test to verify that when a logit_bias mapping is provided, the bias
+    value is added to the logit of the specified token while all other
+    logits are left unchanged.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    sampling_metadata.logit_bias = _create_logit_bias(
+        batch_size=batch_size,
+        vocab_size=VOCAB_SIZE,
+        bias_value=bias_value,
+    )
+    sampler = Sampler()
+    logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        biased_index = min(batch_idx, VOCAB_SIZE - 1)
+        for token_id in range(VOCAB_SIZE):
+            if biased_index == token_id:
+                assert logits_for_req[token_id] == bias_value + 1e-2
+            else:
+                assert logits_for_req[token_id] == 1e-2
diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py
index 1a2771baba963..6c2478bf662f2 100644
--- a/vllm/v1/sample/metadata.py
+++ b/vllm/v1/sample/metadata.py
@@ -32,3 +32,5 @@ class SamplingMetadata:
     output_token_ids: List[List[int]]
     min_tokens: List[int]
     stop_token_ids: List[Set[int]]
+
+    logit_bias: List[Optional[Dict[int, float]]]
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 43fd64aaaa828..2713fa47fcf82 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -166,3 +166,17 @@ def apply_penalties(
                 sampling_metadata.repetition_penalties,
                 sampling_metadata.output_token_ids)
         return logits
+
+    def apply_logits_bias(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        # TODO(houseroad): this implementation is extremely inefficient.
+        # One idea is to implement this as a PyTorch C++ op, and we may
+        # even optimize the logit_bias layout.
+        for i, logit_bias in enumerate(sampling_metadata.logit_bias):
+            if logit_bias:
+                for token_id, bias in logit_bias.items():
+                    logits[i, token_id] += bias
+        return logits
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index d5b8fd2184156..c19cd3db6599c 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -3,7 +3,7 @@
 # Datastructures defining an input batch
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -180,6 +180,9 @@ def __init__(
         # that are currently in the prefill phase.
         self.num_prompt_logprobs: Dict[str, int] = {}
 
+        self.logit_bias: List[Optional[Dict[int,
+                                            float]]] = [None] * max_num_reqs
+
     def add_request(
         self,
         request: "CachedRequestState",
@@ -244,6 +247,9 @@ def add_request(
             self.num_logprobs[req_id] = sampling_params.logprobs
         if sampling_params.prompt_logprobs is not None:
             self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
+        if sampling_params.logit_bias is not None:
+            self.logit_bias[req_index] = self._construct_logit_bias(
+                sampling_params.logit_bias)
 
         # Add request lora ID
         if request.lora_request:
@@ -284,6 +290,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
             self.lora_id_to_lora_request.pop(lora_id)
             self.request_lora_mapping[req_index] = 0
 
+        self.logit_bias[req_index] = None
         return req_index
 
     def clear(self) -> None:
@@ -302,6 +309,7 @@ def clear(self) -> None:
         self.request_lora_mapping.fill(0)
         self.lora_id_to_lora_request.clear()
         self.lora_id_to_request_ids.clear()
+        self.logit_bias = [None] * self.max_num_reqs
 
     def condense(self, empty_req_indices: List[int]) -> None:
         if self.num_reqs == 0:
@@ -421,6 +429,7 @@ def make_sampling_metadata(
             min_tokens=self.min_tokens[:self.num_reqs],
             stop_token_ids=self.stop_token_ids[:self.num_reqs],
             no_penalties=self.no_penalties,
+            logit_bias=self.logit_bias,
         )
 
     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
@@ -463,6 +472,22 @@ def make_lora_inputs(
 
         return prompt_lora_mapping, token_lora_mapping, active_lora_requests
 
+    def _construct_logit_bias(
+        self, logit_bias: Union[Dict[int, float],
+                                Dict[str, float]]) -> Dict[int, float]:
+        try:
+            # Convert token_id to integer
+            # Clamp the bias between -100 and 100 per OpenAI API spec
+            clamped_logit_bias: Dict[int, float] = {
+                int(token_id): min(100.0, max(-100.0, bias))
+                for token_id, bias in logit_bias.items()
+            }
+        except ValueError as exc:
+            raise ValueError(
+                "Found token_id in logit_bias that is not "
+                "an integer or string representing an integer") from exc
+        return clamped_logit_bias
+
     @property
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)
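
Reviewer note (illustrative, not part of the patch): the TODO in apply_logits_bias observes that the per-request Python loop over logit_bias dicts is inefficient. The sketch below shows one possible batched alternative, assuming logits has shape [num_reqs, vocab_size] and logit_bias is the same List[Optional[Dict[int, float]]] carried by SamplingMetadata. The helper name and the flattening step are assumptions for illustration only, not vLLM APIs and not part of this change.

from typing import Dict, List, Optional

import torch


def apply_logits_bias_batched(
    logits: torch.Tensor,
    logit_bias: List[Optional[Dict[int, float]]],
) -> torch.Tensor:
    """Add per-request logit biases to a [num_reqs, vocab_size] tensor."""
    rows: List[int] = []
    cols: List[int] = []
    vals: List[float] = []
    # Flatten the per-request {token_id: bias} dicts into coordinate lists.
    for req_idx, bias_dict in enumerate(logit_bias):
        if not bias_dict:
            continue
        for token_id, bias in bias_dict.items():
            rows.append(req_idx)
            cols.append(token_id)
            vals.append(bias)
    if not vals:
        return logits
    row_idx = torch.tensor(rows, dtype=torch.long, device=logits.device)
    col_idx = torch.tensor(cols, dtype=torch.long, device=logits.device)
    bias_vals = torch.tensor(vals, dtype=logits.dtype, device=logits.device)
    # One indexed, accumulating write instead of one update per biased token.
    logits.index_put_((row_idx, col_idx), bias_vals, accumulate=True)
    return logits

Building the index lists still costs a Python loop over the dicts, but all tensor updates collapse into a single index_put_ call, which avoids issuing one small GPU write per biased token.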