[V1][Core] min_p sampling support (#13191)
Signed-off-by: Aoyu <[email protected]>
Co-authored-by: Aoyu <[email protected]>
AoyuQC authored Feb 14, 2025
1 parent 3bcb8c7 commit a12934d
Showing 4 changed files with 96 additions and 0 deletions.
42 changes: 42 additions & 0 deletions tests/v1/sample/test_sampler.py
@@ -81,6 +81,8 @@ def _create_default_sampling_metadata(
top_k=torch.empty(batch_size, ),
no_top_p=True,
no_top_k=True,
min_p=torch.empty(batch_size, ),
no_min_p=True,
generators={},
max_num_logprobs=0,
prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
@@ -336,6 +338,46 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
non_penalized_token_id in output_tokens)


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("min_p", [0.0, 0.1])
def test_sampler_min_p(device: str, batch_size: int, min_p: float):
"""
Tests that when min_p is applied, tokens with probability below
min_p * max_prob are masked with -inf.
"""
torch.set_default_device(device)
fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)

# Create one dominant token per batch
for i in range(batch_size):
fake_logits[i, 0] = 10.0 # High logit for first token
fake_logits[i, 1:] = 1e-2 # Others remain low

sampling_metadata = _create_default_sampling_metadata(
NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))

# Configure min_p parameters
sampling_metadata.min_p = torch.full((batch_size, ), min_p, device=device)

sampler = Sampler()
logits = sampler.apply_min_p(fake_logits, sampling_metadata.min_p)
logits = logits.cpu()

for batch_idx in range(batch_size):
for token_id in range(VOCAB_SIZE):
if token_id == 0:
# Dominant token should always be unmasked
assert logits[batch_idx][token_id] != -float("inf")
else:
if min_p > 0.0:
# Non-dominant tokens should be masked when min_p > 0
assert logits[batch_idx][token_id] == -float("inf")
else:
# No masking when min_p is 0
assert logits[batch_idx][token_id] != -float("inf")


@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("batch_size", [1, 2, 32])
@pytest.mark.parametrize("bias_value", [-0.1, 1.2])
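For reference, the property this test checks can be reproduced outside the test harness. The snippet below is a minimal standalone sketch with a toy 4-token vocabulary; the values are illustrative and are not the test's actual constants.

import torch

# One dominant token (logit 10.0) and three low-probability tokens (logit 1e-2),
# mirroring the shape of the test's fake logits.
logits = torch.tensor([[10.0, 1e-2, 1e-2, 1e-2]])
min_p = 0.1

probs = torch.softmax(logits, dim=-1)                  # dominant prob ~= 0.9999
threshold = min_p * probs.amax(dim=-1, keepdim=True)   # ~= 0.1 * 0.9999
masked = logits.masked_fill(probs < threshold, float("-inf"))
print(masked)  # only the dominant logit survives; the others become -inf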
2 changes: 2 additions & 0 deletions vllm/v1/sample/metadata.py
@@ -17,6 +17,8 @@ class SamplingMetadata:
top_k: torch.Tensor
no_top_p: bool
no_top_k: bool
min_p: torch.Tensor
no_min_p: bool

generators: Dict[int, torch.Generator]

26 changes: 26 additions & 0 deletions vllm/v1/sample/sampler.py
@@ -93,6 +93,10 @@ def sample(
sampling_metadata.no_top_p,
sampling_metadata.top_p,
)

if not sampling_metadata.no_min_p:
logits = self.apply_min_p(logits, sampling_metadata.min_p)

if sampling_metadata.all_random:
return random_sampled

@@ -169,6 +173,28 @@ def apply_penalties(
sampling_metadata.output_token_ids)
return logits

def apply_min_p(
self,
logits: torch.Tensor,
min_p: torch.Tensor,
) -> torch.Tensor:
"""
Filters logits using adaptive probability thresholding.
"""
# Convert logits to probability distribution
probability_values = torch.nn.functional.softmax(logits, dim=-1)
# Calculate maximum probabilities per sequence
max_probabilities = torch.amax(probability_values,
dim=-1,
keepdim=True)
# Reshape min_p for broadcasting
adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
# Identify valid tokens using threshold comparison
valid_token_mask = probability_values >= adjusted_min_p
# Apply mask using boolean indexing
logits[~valid_token_mask] = -float('inf')
return logits

def apply_logits_bias(
self,
logits: torch.Tensor,
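Note that apply_min_p filters the logits in place through boolean indexing, so callers receive the same tensor they passed in. For comparison, an out-of-place formulation of the same thresholding, written here only as an illustrative sketch (it is not part of this commit), could look like:

import torch

def min_p_filter(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    # Keep tokens whose probability is at least min_p * max_prob per sequence;
    # set the rest to -inf without mutating the input tensor.
    probs = torch.softmax(logits, dim=-1)
    threshold = min_p.unsqueeze(1) * probs.amax(dim=-1, keepdim=True)
    return torch.where(probs >= threshold, logits,
                       torch.full_like(logits, float("-inf")))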
26 changes: 26 additions & 0 deletions vllm/v1/worker/gpu_input_batch.py
@@ -14,6 +14,8 @@
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable

_SAMPLING_EPS = 1e-5

if TYPE_CHECKING:
from vllm.multimodal.inputs import PlaceholderRange

@@ -120,6 +122,16 @@ def __init__(
self.top_k_cpu = self.top_k_cpu_tensor.numpy()
self.top_k_reqs: Set[str] = set()

self.min_p = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device=device)
self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
dtype=torch.float32,
device="cpu",
pin_memory=pin_memory)
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
self.min_p_reqs: Set[str] = set()

# Frequency penalty related data structures
self.frequency_penalties = torch.empty((max_num_reqs, ),
dtype=torch.float,
@@ -223,8 +235,11 @@ def add_request(
self.top_k_cpu[req_index] = sampling_params.top_k
if sampling_params.top_k > 0:
self.top_k_reqs.add(req_id)
self.min_p_cpu[req_index] = sampling_params.min_p
self.frequency_penalties_cpu[
req_index] = sampling_params.frequency_penalty
if sampling_params.min_p > _SAMPLING_EPS:
self.min_p_reqs.add(req_id)
if sampling_params.frequency_penalty != 0.0:
self.frequency_penalties_reqs.add(req_id)
self.presence_penalties_cpu[
@@ -273,6 +288,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.min_p_reqs.discard(req_id)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
@@ -299,6 +315,7 @@ def clear(self) -> None:
self.random_reqs.clear()
self.top_p_reqs.clear()
self.top_k_reqs.clear()
self.min_p_reqs.clear()
self.frequency_penalties_reqs.clear()
self.presence_penalties_reqs.clear()
self.repetition_penalties_reqs.clear()
@@ -354,6 +371,7 @@ def condense(self, empty_req_indices: List[int]) -> None:
empty_index] = self.presence_penalties_cpu[last_req_index]
self.repetition_penalties_cpu[
empty_index] = self.repetition_penalties_cpu[last_req_index]
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
self.min_tokens[empty_index] = self.min_tokens[last_req_index]
self.stop_token_ids[empty_index] = self.stop_token_ids[
last_req_index]
@@ -381,6 +399,8 @@ def make_sampling_metadata(
self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
self.top_k[:self.num_reqs].copy_(
self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
self.min_p[:self.num_reqs].copy_(
self.min_p_cpu_tensor[:self.num_reqs], non_blocking=True)
if not self.no_penalties:
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
@@ -421,6 +441,8 @@
all_random=self.all_random,
top_p=self.top_p[:self.num_reqs],
top_k=self.top_k[:self.num_reqs],
min_p=self.min_p[:self.num_reqs],
no_min_p=self.no_min_p,
no_top_p=self.no_top_p,
no_top_k=self.no_top_k,
generators=self.generators,
@@ -497,6 +519,10 @@ def no_top_p(self) -> bool:
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0

@property
def no_min_p(self) -> bool:
return len(self.min_p_reqs) == 0

@property
def no_penalties(self) -> bool:
return (len(self.presence_penalties_reqs) == 0
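End to end, min_p flows from SamplingParams into this batch state: add_request records per-request values in min_p_cpu and tracks requests above _SAMPLING_EPS in min_p_reqs, make_sampling_metadata copies the active slice to the GPU tensor, and the sampler applies the mask only when no_min_p is False. A minimal usage sketch against vLLM's offline LLM API follows; the model name is a placeholder, not something this commit prescribes.

from vllm import LLM, SamplingParams

# min_p keeps only tokens whose probability is at least
# min_p * (probability of the most likely token).
params = SamplingParams(temperature=0.8, min_p=0.1)

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)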
