Skip to content

Commit

Permalink
[fix] Clamp logprob with dtype min to prevent -inf (#3224)
Browse files (browse the repository at this point in the history)
  • Loading branch information
ByronHsu authored Jan 31, 2025
1 parent 3ee6223 commit 734daed
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
7 changes: 5 additions & 2 deletions python/sglang/srt/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,11 @@ def forward(
# NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
# https://github.com/flashinfer-ai/flashinfer/issues/708
# so we use the torch implementation.

# clamp to avoid -inf
logprobs = torch.log(
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
)
).clamp(min=torch.finfo(probs.dtype).min)

max_top_k_round, batch_size = 32, probs.shape[0]
uniform_samples = torch.rand(
Expand Down Expand Up @@ -109,9 +111,10 @@ def forward(
sampling_info.need_min_p_sampling,
)
if return_logprob:
# clamp to avoid -inf
logprobs = torch.log(
top_p_normalize_probs_torch(probs, sampling_info.top_ps)
)
).clamp(min=torch.finfo(probs.dtype).min)
else:
raise ValueError(
f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def tearDownClass(cls):
def run_decode(
self,
return_logprob=True,
top_logprobs_num=3,
top_logprobs_num=5,
return_text=True,
n=1,
**sampling_params,
Expand All @@ -58,8 +58,7 @@ def run_decode(
"logprob_start_len": 0,
},
)
print(json.dumps(response.json()))
print("=" * 100)
assert response.status_code == 200, "Request failed: " + response.text

def test_default_values(self):
self.run_decode()
Expand Down Expand Up @@ -112,4 +111,4 @@ def test_repetition_penalty(self):


if __name__ == "__main__":
unittest.main()
unittest.main(verbosity=3)

0 comments on commit 734daed

Please sign in to comment.