diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 527a7d499b6..dc53e4445db 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -17,6 +17,8 @@ import torch import torch.nn.functional as F +from sglang.srt.utils import get_compiler_backend + def fused_topk_native( hidden_states: torch.Tensor, @@ -74,6 +76,7 @@ def fused_topk( # This is used by the Deepseek-V2 model +@torch.compile(dynamic=True, backend=get_compiler_backend()) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -108,6 +111,7 @@ def grouped_topk( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) +@torch.compile(dynamic=True, backend=get_compiler_backend()) def biased_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor,