Skip to content

Commit

Permalink
Optimize MoE topk with torch compile (#3236)
Browse files Browse the repository at this point in the history
  • Loading branch information
ispobock authored Jan 31, 2025
1 parent 7811bfd commit 1ebe1d6
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import torch
import torch.nn.functional as F

from sglang.srt.utils import get_compiler_backend


def fused_topk_native(
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -74,6 +76,7 @@ def fused_topk(


# This is used by the Deepseek-V2 model
@torch.compile(dynamic=True, backend=get_compiler_backend())
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
Expand Down Expand Up @@ -108,6 +111,7 @@ def grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


@torch.compile(dynamic=True, backend=get_compiler_backend())
def biased_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
Expand Down

0 comments on commit 1ebe1d6

Please sign in to comment.