Commit 7b17c04: Support logit_bias in v1 Sampler

Signed-off-by: Lu Fang <[email protected]>
houseroad committed Feb 12, 2025
1 parent e92694b commit 7b17c04

Showing 4 changed files with 100 additions and 13 deletions.
70 changes: 58 additions & 12 deletions tests/v1/sample/test_sampler.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

from typing import List, Set, Tuple
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import pytest
@@ -45,6 +45,18 @@ def _create_prompt_tokens_tensor(
)


def _create_logit_bias(
batch_size: int,
vocab_size: int,
bias_value: float,
) -> List[Optional[Dict[int, float]]]:
res: List[Optional[Dict[int, float]]] = []
for i in range(batch_size):
logit_bias = {min(i, vocab_size - 1): bias_value}
res.append(logit_bias)
return res


def _create_default_sampling_metadata(
num_output_tokens: int,
batch_size: int,
@@ -80,6 +92,7 @@ def _create_default_sampling_metadata(
no_penalties=True,
min_tokens=[],
stop_token_ids=[],
logit_bias=[None] * batch_size,
)
return fake_sampling_metadata

@@ -89,14 +102,14 @@ def _generate_min_token_penalties_and_stop_tokens(
batch_indices_for_min_token_penalty: List[int]
) -> Tuple[List[int], List[Set[int]]]:
"""
Generates and returns a list of minimum token penalties (`min_tokens`)
and a corresponding list of stop token IDs (`stop_token_ids`) for each
batch.
If a batch index is included in `batch_indices_for_min_token_penalty`,
a higher `min_tokens` value is assigned (within a randomized range),
and a random set of stop token IDs is created. Otherwise, a lower
`min_tokens` value is assigned, and the stop token IDs set is empty.
"""
stop_token_ids: List[Set[int]] = []
min_tokens: List[int] = []
@@ -120,7 +133,7 @@ def _create_weighted_output_token_list(
batch_size: int,
vocab_size: int) -> Tuple[List[List[int]], List[List[int]]]:
"""
Creates an output token list where each token occurs a distinct
Creates an output token list where each token occurs a distinct
number of times.
For each batch, a random subset of token IDs is selected from the
@@ -129,8 +142,8 @@ def _create_weighted_output_token_list(
Returns:
Tuple[List[List[int]], List[List[int]]]:
- The first element is the output token list, where each sublist
corresponds to a batch and contains tokens with weighted
frequencies.
- The second element is a list of distinct token IDs for each
batch, ordered by their frequency in the corresponding output
@@ -155,7 +168,7 @@ def _create_weighted_output_token_list(
@pytest.mark.parametrize("batch_size", [1, 2, 32])
def test_sampler_min_tokens_penalty(device: str, batch_size: int):
"""
Tests that if the number of output tokens is less than
SamplingParams.min_tokens then we will set the logits for
the stop token ids to -inf.
"""
@@ -283,7 +296,7 @@ def test_sampler_frequency_penalty(device: str, batch_size: int,
def test_sampler_repetition_penalty(device: str, batch_size: int,
repetition_penalty: float):
"""
Test to verify that when the repetition penalty is enabled, tokens
are penalized based on their presence in the prompt or the existing
output.
"""
@@ -321,3 +334,36 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
penalized_token_id not in output_tokens)
assert (non_penalized_token_id in prompt_tokens or \
non_penalized_token_id in output_tokens)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
def test_sampler_logit_bias(device: str, batch_size: int, bias_value: float):
"""
Test to verify that when a logit bias is provided for a token, the
bias value is added to that token's logit while all other logits
remain unchanged.
"""
torch.set_default_device(device)
# Create fake logits where each token is assigned the same
# logit value.
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
sampling_metadata.logit_bias = _create_logit_bias(
batch_size=batch_size,
vocab_size=VOCAB_SIZE,
bias_value=bias_value,
)
sampler = Sampler()
logits = sampler.apply_logits_bias(fake_logits, sampling_metadata)
logits = logits.cpu()
for batch_idx in range(batch_size):
logits_for_req = logits[batch_idx]
biased_index = min(batch_idx, VOCAB_SIZE - 1)
for token_id in range(VOCAB_SIZE):
if biased_index == token_id:
assert logits_for_req[token_id] == bias_value + 1e-2
else:
assert logits_for_req[token_id] == 1e-2
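
For context, the sampler only consumes the per-request bias dicts; they originate from the `logit_bias` field on `SamplingParams`, which `gpu_input_batch.py` below copies into the batch. A minimal usage sketch, assuming the offline `LLM` entry point and an illustrative model name (neither is part of this commit):

# Illustrative only; the model name and the exact generate() call are
# assumptions, not something introduced by this change.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(
    max_tokens=16,
    # Push token 42 up and token 7 down; values are clamped to [-100, 100]
    # by InputBatch._construct_logit_bias before sampling.
    logit_bias={42: 5.0, 7: -5.0},
)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
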
2 changes: 2 additions & 0 deletions vllm/v1/sample/metadata.py
@@ -32,3 +32,5 @@ class SamplingMetadata:
output_token_ids: List[List[int]]
min_tokens: List[int]
stop_token_ids: List[Set[int]]

logit_bias: List[Optional[Dict[int, float]]]
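
The new field holds one optional bias mapping per request in the batch. An illustrative value for a batch of three requests (the concrete numbers are made up):

# Request 0 boosts token 128, request 1 carries no bias, and request 2
# effectively suppresses token 50256 (the clamp floor is -100.0).
logit_bias = [
    {128: 2.0},
    None,
    {50256: -100.0},
]
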
14 changes: 14 additions & 0 deletions vllm/v1/sample/sampler.py
@@ -166,3 +166,17 @@ def apply_penalties(
sampling_metadata.repetition_penalties,
sampling_metadata.output_token_ids)
return logits

def apply_logits_bias(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
# TODO(houseroad): this implementation is extremely inefficient.
# One idea is to implement this as a PyTorch C++ op, and we may
# even optimize the logit_bias layout.
for i, logit_bias in enumerate(sampling_metadata.logit_bias):
if logit_bias:
for token_id, bias in logit_bias.items():
logits[i, token_id] += bias
return logits
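
One possible direction for the TODO above, not part of this commit, is to flatten the per-request dicts into index tensors and apply all biases with a single indexed add, so the GPU sees one kernel launch instead of one scalar update per biased token. A rough sketch under that assumption:

import torch
from typing import Dict, List, Optional

def apply_logits_bias_vectorized(
    logits: torch.Tensor,
    logit_bias: List[Optional[Dict[int, float]]],
) -> torch.Tensor:
    # Collect (row, token_id, bias) triples on the host, then apply them
    # in one index_put_ call with accumulation.
    rows: List[int] = []
    cols: List[int] = []
    vals: List[float] = []
    for i, bias_dict in enumerate(logit_bias):
        if bias_dict:
            for token_id, bias in bias_dict.items():
                rows.append(i)
                cols.append(token_id)
                vals.append(bias)
    if rows:
        row_idx = torch.tensor(rows, device=logits.device, dtype=torch.long)
        col_idx = torch.tensor(cols, device=logits.device, dtype=torch.long)
        bias_vals = torch.tensor(vals, device=logits.device, dtype=logits.dtype)
        logits.index_put_((row_idx, col_idx), bias_vals, accumulate=True)
    return logits
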
27 changes: 26 additions & 1 deletion vllm/v1/worker/gpu_input_batch.py
@@ -3,7 +3,7 @@
# Datastructures defining an input batch

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union

import numpy as np
import torch
@@ -180,6 +180,9 @@ def __init__(
# that are currently in the prefill phase.
self.num_prompt_logprobs: Dict[str, int] = {}

self.logit_bias: List[Optional[Dict[int,
float]]] = [None] * max_num_reqs

def add_request(
self,
request: "CachedRequestState",
@@ -244,6 +247,9 @@ def add_request(
self.num_logprobs[req_id] = sampling_params.logprobs
if sampling_params.prompt_logprobs is not None:
self.num_prompt_logprobs[req_id] = sampling_params.prompt_logprobs
if sampling_params.logit_bias is not None:
self.logit_bias[req_index] = self._construct_logit_bias(
sampling_params.logit_bias)

# Add request lora ID
if request.lora_request:
@@ -284,6 +290,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
self.lora_id_to_lora_request.pop(lora_id)
self.request_lora_mapping[req_index] = 0

self.logit_bias[req_index] = None
return req_index

def clear(self) -> None:
@@ -302,6 +309,7 @@ def clear(self) -> None:
self.request_lora_mapping.fill(0)
self.lora_id_to_lora_request.clear()
self.lora_id_to_request_ids.clear()
self.logit_bias = [None] * self.max_num_reqs

def condense(self, empty_req_indices: List[int]) -> None:
if self.num_reqs == 0:
@@ -421,6 +429,7 @@ def make_sampling_metadata(
min_tokens=self.min_tokens[:self.num_reqs],
stop_token_ids=self.stop_token_ids[:self.num_reqs],
no_penalties=self.no_penalties,
logit_bias=self.logit_bias,
)

def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
@@ -463,6 +472,22 @@ def make_lora_inputs(

return prompt_lora_mapping, token_lora_mapping, active_lora_requests

def _construct_logit_bias(
self, logit_bias: Union[Dict[int, float],
Dict[str, float]]) -> Dict[int, float]:
try:
# Convert token_id to integer
# Clamp the bias between -100 and 100 per OpenAI API spec
clamped_logit_bias: Dict[int, float] = {
int(token_id): min(100.0, max(-100.0, bias))
for token_id, bias in logit_bias.items()
}
except ValueError as exc:
raise ValueError(
"Found token_id in logit_bias that is not "
"an integer or string representing an integer") from exc
return clamped_logit_bias

@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
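
To make the key conversion and clamping in `_construct_logit_bias` concrete, here is what it does to a hypothetical request-level dict whose keys arrive as strings (as they may from an OpenAI-style client):

# Hypothetical input; the dict comprehension mirrors the clamping logic
# added above and is shown only for illustration.
raw_bias = {"50256": 250.0, 42: -7.5, "7": -150.0}
clamped = {
    int(token_id): min(100.0, max(-100.0, bias))
    for token_id, bias in raw_bias.items()
}
# clamped == {50256: 100.0, 42: -7.5, 7: -100.0}
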
