Fix min_p sampling
The flashinfer backend won't return success states for min_p sampling.
zifeitong committed Jan 29, 2025
1 parent 9ae1db0 commit 5bbc421
Showing 1 changed file with 4 additions and 4 deletions.
python/sglang/srt/layers/sampler.py (4 additions, 4 deletions)
@@ -51,7 +51,7 @@ def forward(
             if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                batch_next_token_ids, success = min_p_sampling_from_probs(
+                batch_next_token_ids = min_p_sampling_from_probs(
                     probs, uniform_samples, sampling_info.min_ps
                 )
             else:
@@ -63,9 +63,9 @@ def forward(
                     filter_apply_order="joint",
                 )

-            if not torch.all(success):
-                logger.warning("Detected errors during sampling!")
-                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+                if not torch.all(success):
+                    logger.warning("Detected errors during sampling!")
+                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # Here we provide a slower fallback implementation.
             batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
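For context, a minimal sketch of the call-site difference this commit addresses, assuming a flashinfer build that behaves as the commit message describes: min_p_sampling_from_probs returns only the sampled token ids, while top_k_top_p_sampling_from_probs still returns an (ids, success) pair, so the success check only applies to the top-k/top-p branch. The tensor shapes, dtypes, and parameter values below are illustrative, not taken from the repository.

    # Illustrative sketch only; assumes flashinfer is installed, a CUDA device is
    # available, and the return conventions match the commit message above.
    import torch
    from flashinfer.sampling import (
        min_p_sampling_from_probs,
        top_k_top_p_sampling_from_probs,
    )

    batch_size, vocab_size, max_top_k_round = 4, 32000, 32
    probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda"), dim=-1)
    uniform_samples = torch.rand(max_top_k_round, batch_size, device="cuda")

    # min_p path: a single tensor of token ids, so there is no success mask to check.
    min_ps = torch.full((batch_size,), 0.05, device="cuda")
    next_token_ids = min_p_sampling_from_probs(probs, uniform_samples, min_ps)

    # top-k/top-p path: still returns (ids, success), so the fallback-to-zeros
    # check from sampler.py stays on this branch only.
    top_ks = torch.full((batch_size,), 50, device="cuda", dtype=torch.int32)
    top_ps = torch.full((batch_size,), 0.9, device="cuda")
    next_token_ids, success = top_k_top_p_sampling_from_probs(
        probs, uniform_samples, top_ks, top_ps, filter_apply_order="joint"
    )
    if not torch.all(success):
        next_token_ids = torch.zeros_like(next_token_ids)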
