diff --git a/tests/v1/sample/test_sampler.py b/tests/v1/sample/test_sampler.py
index 03606af3867d7..cfef475d8dee2 100644
--- a/tests/v1/sample/test_sampler.py
+++ b/tests/v1/sample/test_sampler.py
@@ -81,6 +81,8 @@ def _create_default_sampling_metadata(
         top_k=torch.empty(batch_size, ),
         no_top_p=True,
         no_top_k=True,
+        min_p=torch.empty(batch_size, ),
+        no_min_p=True,
         generators={},
         max_num_logprobs=0,
         prompt_token_ids=_create_prompt_tokens_tensor(prompt_token_ids,
@@ -336,6 +338,46 @@ def test_sampler_repetition_penalty(device: str, batch_size: int,
             non_penalized_token_id in output_tokens)
 
 
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("min_p", [0.0, 0.1])
+def test_sampler_min_p(device: str, batch_size: int, min_p: float):
+    """
+    Tests that when min_p is applied, tokens with probability below
+    min_p * max_prob are masked with -inf.
+    """
+    torch.set_default_device(device)
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+
+    # Create one dominant token per batch
+    for i in range(batch_size):
+        fake_logits[i, 0] = 10.0  # High logit for first token
+        fake_logits[i, 1:] = 1e-2  # Others remain low
+
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+
+    # Configure min_p parameters
+    sampling_metadata.min_p = torch.full((batch_size, ), min_p, device=device)
+
+    sampler = Sampler()
+    logits = sampler.apply_min_p(fake_logits, sampling_metadata.min_p)
+    logits = logits.cpu()
+
+    for batch_idx in range(batch_size):
+        for token_id in range(VOCAB_SIZE):
+            if token_id == 0:
+                # Dominant token should always be unmasked
+                assert logits[batch_idx][token_id] != -float("inf")
+            else:
+                if min_p > 0.0:
+                    # Non-dominant tokens should be masked when min_p > 0
+                    assert logits[batch_idx][token_id] == -float("inf")
+                else:
+                    # No masking when min_p is 0
+                    assert logits[batch_idx][token_id] != -float("inf")
+
+
 @pytest.mark.parametrize("device", CUDA_DEVICES)
 @pytest.mark.parametrize("batch_size", [1, 2, 32])
 @pytest.mark.parametrize("bias_value", [-0.1, 1.2])
diff --git a/vllm/v1/sample/metadata.py b/vllm/v1/sample/metadata.py
index 6c2478bf662f2..cfcc54b7e3436 100644
--- a/vllm/v1/sample/metadata.py
+++ b/vllm/v1/sample/metadata.py
@@ -17,6 +17,8 @@ class SamplingMetadata:
     top_k: torch.Tensor
     no_top_p: bool
     no_top_k: bool
+    min_p: torch.Tensor
+    no_min_p: bool
 
     generators: Dict[int, torch.Generator]
 
diff --git a/vllm/v1/sample/sampler.py b/vllm/v1/sample/sampler.py
index 739dc811d5d93..ac32c90d67699 100644
--- a/vllm/v1/sample/sampler.py
+++ b/vllm/v1/sample/sampler.py
@@ -93,6 +93,10 @@ def sample(
             sampling_metadata.no_top_p,
             sampling_metadata.top_p,
         )
+
+        if not sampling_metadata.no_min_p:
+            logits = self.apply_min_p(logits, sampling_metadata.min_p)
+
         if sampling_metadata.all_random:
             return random_sampled
 
@@ -169,6 +173,28 @@ def apply_penalties(
                 sampling_metadata.output_token_ids)
         return logits
 
+    def apply_min_p(
+        self,
+        logits: torch.Tensor,
+        min_p: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Filters logits using adaptive probability thresholding.
+        """
+        # Convert logits to probability distribution
+        probability_values = torch.nn.functional.softmax(logits, dim=-1)
+        # Calculate maximum probabilities per sequence
+        max_probabilities = torch.amax(probability_values,
+                                       dim=-1,
+                                       keepdim=True)
+        # Reshape min_p for broadcasting
+        adjusted_min_p = min_p.unsqueeze(1) * max_probabilities
+        # Identify valid tokens using threshold comparison
+        valid_token_mask = probability_values >= adjusted_min_p
+        # Apply mask using boolean indexing
+        logits[~valid_token_mask] = -float('inf')
+        return logits
+
     def apply_logits_bias(
         self,
         logits: torch.Tensor,
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index d52b8827d35ee..1604aeab3206d 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -14,6 +14,8 @@
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.worker.block_table import BlockTable
 
+_SAMPLING_EPS = 1e-5
+
 if TYPE_CHECKING:
     from vllm.multimodal.inputs import PlaceholderRange
 
@@ -120,6 +122,16 @@ def __init__(
         self.top_k_cpu = self.top_k_cpu_tensor.numpy()
         self.top_k_reqs: Set[str] = set()
 
+        self.min_p = torch.empty((max_num_reqs, ),
+                                 dtype=torch.float32,
+                                 device=device)
+        self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
+                                            dtype=torch.float32,
+                                            device="cpu",
+                                            pin_memory=pin_memory)
+        self.min_p_cpu = self.min_p_cpu_tensor.numpy()
+        self.min_p_reqs: Set[str] = set()
+
         # Frequency penalty related data structures
         self.frequency_penalties = torch.empty((max_num_reqs, ),
                                                dtype=torch.float,
@@ -223,8 +235,11 @@ def add_request(
         self.top_k_cpu[req_index] = sampling_params.top_k
         if sampling_params.top_k > 0:
             self.top_k_reqs.add(req_id)
+        self.min_p_cpu[req_index] = sampling_params.min_p
         self.frequency_penalties_cpu[
             req_index] = sampling_params.frequency_penalty
+        if sampling_params.min_p > _SAMPLING_EPS:
+            self.min_p_reqs.add(req_id)
         if sampling_params.frequency_penalty != 0.0:
             self.frequency_penalties_reqs.add(req_id)
         self.presence_penalties_cpu[
@@ -273,6 +288,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.random_reqs.discard(req_id)
         self.top_p_reqs.discard(req_id)
         self.top_k_reqs.discard(req_id)
+        self.min_p_reqs.discard(req_id)
         self.frequency_penalties_reqs.discard(req_id)
         self.presence_penalties_reqs.discard(req_id)
         self.repetition_penalties_reqs.discard(req_id)
@@ -299,6 +315,7 @@ def clear(self) -> None:
         self.random_reqs.clear()
         self.top_p_reqs.clear()
         self.top_k_reqs.clear()
+        self.min_p_reqs.clear()
         self.frequency_penalties_reqs.clear()
         self.presence_penalties_reqs.clear()
         self.repetition_penalties_reqs.clear()
@@ -354,6 +371,7 @@ def condense(self, empty_req_indices: List[int]) -> None:
                 empty_index] = self.presence_penalties_cpu[last_req_index]
             self.repetition_penalties_cpu[
                 empty_index] = self.repetition_penalties_cpu[last_req_index]
+            self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
             self.min_tokens[empty_index] = self.min_tokens[last_req_index]
             self.stop_token_ids[empty_index] = self.stop_token_ids[
                 last_req_index]
@@ -381,6 +399,8 @@ def make_sampling_metadata(
                 self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True)
             self.top_k[:self.num_reqs].copy_(
                 self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True)
+            self.min_p[:self.num_reqs].copy_(
+                self.min_p_cpu_tensor[:self.num_reqs], non_blocking=True)
             if not self.no_penalties:
                 # Since syncing these tensors is expensive only copy them
                 # if necessary i.e. if there are requests which require
@@ -421,6 +441,8 @@ def make_sampling_metadata(
             all_random=self.all_random,
             top_p=self.top_p[:self.num_reqs],
             top_k=self.top_k[:self.num_reqs],
+            min_p=self.min_p[:self.num_reqs],
+            no_min_p=self.no_min_p,
             no_top_p=self.no_top_p,
             no_top_k=self.no_top_k,
             generators=self.generators,
@@ -497,6 +519,10 @@ def no_top_p(self) -> bool:
     def no_top_k(self) -> bool:
         return len(self.top_k_reqs) == 0
 
+    @property
+    def no_min_p(self) -> bool:
+        return len(self.min_p_reqs) == 0
+
     @property
     def no_penalties(self) -> bool:
         return (len(self.presence_penalties_reqs) == 0
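
Note (not part of the patch): the standalone sketch below reproduces the min_p filtering math that apply_min_p implements, using an out-of-place masked_fill in place of the patch's in-place assignment. The helper name min_p_filter and the toy logits/vocabulary are illustrative only.

import torch


def min_p_filter(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    # Convert logits to probabilities and find each row's peak probability.
    probs = torch.softmax(logits, dim=-1)
    max_probs = probs.amax(dim=-1, keepdim=True)  # shape: (batch, 1)
    # Per-request threshold: min_p scaled by that row's maximum probability.
    threshold = min_p.unsqueeze(1) * max_probs
    # Out-of-place equivalent of `logits[~valid_token_mask] = -inf`.
    return logits.masked_fill(probs < threshold, float("-inf"))


# Toy batch of 2 requests over a 4-token vocabulary.
logits = torch.tensor([[10.0, 0.01, 0.01, 0.01],
                       [1.0, 1.0, 1.0, 1.0]])
min_p = torch.tensor([0.1, 0.1])
print(min_p_filter(logits, min_p))

With min_p = 0.1, row 0 keeps only the dominant token (the others fall below 0.1 * max_prob), while row 1 keeps every token because its distribution is uniform; this mirrors the behavior asserted in test_sampler_min_p.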