Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize MoE topk with torch compile #3236

Merged
merged 3 commits into from
Jan 31, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import torch
import torch.nn.functional as F

from sglang.srt.utils import get_compiler_backend


def fused_topk_native(
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -74,6 +76,7 @@ def fused_topk(


# This is used by the Deepseek-V2 model
@torch.compile(dynamic=True, backend=get_compiler_backend())
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
Expand Down Expand Up @@ -108,6 +111,7 @@ def grouped_topk(
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


@torch.compile(dynamic=True, backend=get_compiler_backend())
def biased_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
Expand Down
Loading