From a3660e32f3d0ed0b4b17e953c375c873e92ab365 Mon Sep 17 00:00:00 2001 From: Zifei Tong Date: Wed, 29 Jan 2025 09:11:03 -0800 Subject: [PATCH] Fix min_p sampling crash when using flashinfer backend --- python/sglang/srt/layers/sampler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index b24bfc8dacf..c20e478a1af 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -83,7 +83,7 @@ def forward( if sampling_info.need_min_p_sampling: probs = top_k_renorm_prob(probs, sampling_info.top_ks) probs = top_p_renorm_prob(probs, sampling_info.top_ps) - batch_next_token_ids, success = min_p_sampling_from_probs( + batch_next_token_ids = min_p_sampling_from_probs( probs, uniform_samples, sampling_info.min_ps ) else: @@ -95,9 +95,9 @@ def forward( filter_apply_order="joint", ) - if self.use_nan_detectioin and not torch.all(success): - logger.warning("Detected errors during sampling!") - batch_next_token_ids = torch.zeros_like(batch_next_token_ids) + if self.use_nan_detectioin and not torch.all(success): + logger.warning("Detected errors during sampling!") + batch_next_token_ids = torch.zeros_like(batch_next_token_ids) elif global_server_args_dict["sampling_backend"] == "pytorch": # A slower fallback implementation with torch native operations.