Fix min_p sampling crash when using flashinfer backend #3207
base: main
Conversation
This seems unclear to me. Could you explain under what condition this happens? Checking the definition:

def min_p_sampling_from_probs(
    probs: torch.Tensor,
    uniform_samples: torch.Tensor,
    min_p: Union[torch.Tensor, float],
    deterministic: bool = True,
    check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Fused GPU kernel for `min_p sampling <https://arxiv.org/abs/2407.01082>`_ from probabilities.
    This operator implements GPU-based rejection sampling without explicit sorting.
    The multiple rounds of rejection sampling are implemented in a single CUDA kernel,
    which is more efficient than a naive implementation that launches a series of kernels.

    Parameters
    ----------
    probs: torch.Tensor
        Probabilities, shape ``(batch_size, num_classes)``.
    uniform_samples: torch.Tensor
        The uniform samples used as needles for sampling, shape ``(max_top_k_rounds, batch_size)``,
        where the first dimension is the maximum number of rounds for rejection sampling.
        Expected to be uniformly distributed in ``[0, 1)``.
    min_p: torch.Tensor
        Either a scalar or a tensor of shape ``(batch_size,)``, representing the threshold for min-p sampling.
        If a scalar, the same threshold is used for all requests.
        If a tensor, each request has its own threshold.
    deterministic: bool
        Whether to use the deterministic kernel implementation, default is ``True``.
    check_nan: bool
        Whether to check for NaN in :attr:`probs`, default is ``False``.

    Returns
    -------
    samples: torch.Tensor
        Sampled categories, shape ``(batch_size,)``.
    success: torch.Tensor
        Whether the sampling succeeded within ``max_top_k_rounds`` rounds,
        shape ``(batch_size,)``.

    Examples
    --------
    >>> import torch
    >>> import flashinfer
    >>> torch.manual_seed(42)
    <torch._C.Generator object at 0x7f8b3db06df0>
    >>> batch_size = 4
    >>> vocab_size = 5
    >>> max_rounds = 3
    >>> min_p = torch.full((batch_size,), 0.05).to(0)
    >>> pre_norm_prob = torch.rand(batch_size, vocab_size).to(0)
    >>> norm_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True)
    >>> norm_prob
    tensor([[0.2499, 0.2592, 0.1085, 0.2718, 0.1106],
            [0.2205, 0.0942, 0.2912, 0.3452, 0.0489],
            [0.2522, 0.1602, 0.2346, 0.1532, 0.2000],
            [0.1543, 0.3182, 0.2062, 0.0958, 0.2255]], device='cuda:0')
    >>> uniform_samples = torch.rand(max_rounds, batch_size).to(0)
    >>> samples, success = flashinfer.sampling.min_p_sampling_from_probs(norm_prob, uniform_samples, min_p)
    >>> samples
    tensor([1, 2, 1, 4], device='cuda:0', dtype=torch.int32)
    >>> success
    tensor([True, True, True, True], device='cuda:0')

    Note
    ----
    This function expects float32 inputs, and the output is int32.
    We encourage users to set ``max_rounds`` to a reasonable value, e.g., 32. The actual
    implementation usually uses far fewer rounds for rejection sampling because of early stopping.
    """
    if check_nan:
        if torch.any(torch.isnan(probs)):
            raise ValueError("Input probs contains NaN.")
    return _kernels.min_p_sampling_from_probs(
        probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic
    )
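For intuition, the multi-round rejection loop the docstring describes can be sketched in plain, non-fused PyTorch. This is a reference sketch under my reading of the docstring, not flashinfer's actual CUDA kernel; the function and variable names here are invented, and `min_p` is taken as a scalar for simplicity:

```python
import torch

def min_p_sampling_reference(probs, uniform_samples, min_p):
    """Non-fused reference sketch of min-p rejection sampling.

    probs: (batch_size, num_classes), each row sums to 1.
    uniform_samples: (max_rounds, batch_size), values in [0, 1).
    min_p: scalar threshold in [0, 1].
    """
    max_rounds, batch_size = uniform_samples.shape
    # A token survives min-p filtering if its probability is at least
    # min_p times the largest probability in its row.
    threshold = min_p * probs.max(dim=-1, keepdim=True).values  # (B, 1)
    cdf = probs.cumsum(dim=-1)                                  # (B, V)
    samples = torch.zeros(batch_size, dtype=torch.long)
    success = torch.zeros(batch_size, dtype=torch.bool)
    for r in range(max_rounds):
        # Inverse-CDF draw using this round's uniform "needle".
        draw = torch.searchsorted(cdf, uniform_samples[r].unsqueeze(-1))
        draw = draw.clamp(max=probs.shape[-1] - 1).squeeze(-1)  # (B,)
        drawn_prob = probs.gather(-1, draw.unsqueeze(-1)).squeeze(-1)
        # Accept the draw only if it clears the min-p threshold and the
        # row has not already succeeded in an earlier round.
        accept = (~success) & (drawn_prob >= threshold.squeeze(-1))
        samples = torch.where(accept, draw, samples)
        success = success | accept
    return samples, success
```

Rows that never draw an accepted token within `max_rounds` keep `success=False`, which is exactly the second return value the caller is expected to check.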
Is this the latest version? I am seeing the
The flashinfer backend won't return success states for min_p.
Motivation
Fixes #3201
Modifications
Checklist
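If the root cause is a flashinfer API change (a release returning only the samples tensor instead of a `(samples, success)` tuple, as the description above suggests), one defensive way to support both shapes is a small compatibility shim. This is a hypothetical sketch, not the PR's actual change; `call_min_p` and the `sampler` argument are invented names:

```python
import torch

def call_min_p(sampler, probs, uniform_samples, min_p):
    # Hypothetical compatibility shim: older flashinfer releases returned
    # (samples, success); if a release returns only the samples tensor,
    # synthesize an all-True success so callers can keep unpacking a pair.
    out = sampler(probs, uniform_samples, min_p)
    if isinstance(out, tuple):
        samples, success = out
    else:
        samples = out
        success = torch.ones(samples.shape[0], dtype=torch.bool,
                             device=samples.device)
    return samples, success
```

The shim keeps the calling code's unpacking unchanged regardless of which flashinfer version is installed.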