
Add activation parameters to fused_moe #3170

Merged · 3 commits · Jan 27, 2025
3 changes: 3 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -114,6 +114,7 @@ def __init__(
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
):
super().__init__()

@@ -140,6 +141,7 @@ def __init__(
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.activation = activation

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -166,6 +168,7 @@ def __init__(

def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None
assert self.activation == "silu"

if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
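Note that the EP MoE path only gains the `activation` parameter for API consistency here; its forward still asserts SiLU. A tiny standalone sketch (not sglang code) of what that guard means in practice:

```python
# Standalone illustration (not sglang code): EPMoE stores `activation`, but its
# forward still requires SiLU, mirroring the `assert self.activation == "silu"` above.
def ep_moe_forward(activation: str = "silu") -> str:
    assert activation == "silu", f"{activation=} is not supported by EPMoE yet"
    return "ok"

print(ep_moe_forward("silu"))  # prints: ok
# ep_moe_forward("gelu")       # would raise AssertionError
```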
20 changes: 17 additions & 3 deletions python/sglang/srt/layers/moe/fused_moe_native.py
@@ -8,7 +8,7 @@
import torch
from torch.nn import functional as F

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
from sglang.srt.layers.moe.topk import select_experts


@@ -23,6 +23,7 @@ def fused_moe_forward_native(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
@@ -41,7 +42,12 @@
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = layer.w2_weight[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
if activation == "silu":
x1 = F.silu(x1)
elif activation == "gelu":
x1 = F.gelu(x1)
else:
raise ValueError(f"Unsupported activation: {activation=}")
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -58,6 +64,7 @@ def moe_forward_native(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:

topk_weights, topk_ids = select_experts(
@@ -84,6 +91,13 @@ def moe_forward_native(
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()

if activation == "silu":
act = SiluAndMul()
elif activation == "gelu":
act = GeluAndMul()
else:
raise ValueError(f"Unsupported activation: {activation=}")

outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
@@ -96,7 +110,7 @@ def moe_forward_native(
layer_w2_weight = layer.w2_weight[i]

gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
gate_up = SiluAndMul()(gate_up)
gate_up = act(gate_up)
expert_out = F.linear(gate_up, layer_w2_weight)
outputs.append(expert_out)
start_idx = end_idx
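For reference, this is the gated-activation pattern the native path now dispatches on. A minimal sketch, assuming the gate and up projections are concatenated along the last dimension (matching SiluAndMul / GeluAndMul semantics); the exact GELU variant (erf vs. tanh approximation) is not pinned down here:

```python
# Minimal sketch of the gated activation the native MoE path selects between.
# Assumption: gate and up halves are concatenated along the last dim, matching
# SiluAndMul / GeluAndMul in sglang.srt.layers.activation.
import torch
import torch.nn.functional as F

def gated_act(gate_up: torch.Tensor, activation: str = "silu") -> torch.Tensor:
    gate, up = gate_up.chunk(2, dim=-1)
    if activation == "silu":
        return F.silu(gate) * up
    if activation == "gelu":
        return F.gelu(gate) * up
    raise ValueError(f"Unsupported activation: {activation=}")

x = torch.randn(4, 2 * 128)         # 4 tokens, intermediate size 128
print(gated_act(x, "gelu").shape)   # torch.Size([4, 128])
```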
19 changes: 18 additions & 1 deletion python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -711,6 +711,7 @@ def inplace_fused_experts(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
@@ -726,6 +727,7 @@ def inplace_fused_experts(
topk_weights,
topk_ids,
True,
activation,
use_fp8_w8a8,
use_int8_w8a16,
w1_scale,
@@ -742,6 +744,7 @@ def inplace_fused_experts_fake(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
@@ -767,6 +770,7 @@ def outplace_fused_experts(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
@@ -782,6 +786,7 @@ def outplace_fused_experts(
topk_weights,
topk_ids,
False,
activation,
use_fp8_w8a8,
use_int8_w8a16,
w1_scale,
@@ -798,6 +803,7 @@ def outplace_fused_experts_fake(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
@@ -824,6 +830,7 @@ def fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
@@ -839,6 +846,7 @@ def fused_experts(
w2,
topk_weights,
topk_ids,
activation,
use_fp8_w8a8,
use_int8_w8a16,
w1_scale,
@@ -855,6 +863,7 @@ def fused_experts(
w2,
topk_weights,
topk_ids,
activation,
use_fp8_w8a8,
use_int8_w8a16,
w1_scale,
@@ -872,6 +881,7 @@ def fused_experts_impl(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
@@ -986,7 +996,12 @@ def fused_experts_impl(
block_shape=block_shape,
)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
if activation == "silu":
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
elif activation == "gelu":
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported activation: {activation=}")

invoke_fused_moe_kernel(
intermediate_cache2,
@@ -1042,6 +1057,7 @@ def fused_moe(
topk: int,
renormalize: bool,
inplace: bool = False,
activation: str = "silu",
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
@@ -1111,6 +1127,7 @@ def fused_moe(
topk_weights,
topk_ids,
inplace=inplace,
activation=activation,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
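A hedged usage sketch of the Triton entry point with the new argument. The tensor names, shapes, and dtypes below follow the usual `fused_moe` convention (E experts, gate and up projections stacked in `w1`) and are assumptions, not taken from this diff:

```python
# Hedged usage sketch of the Triton fused MoE with the new `activation` argument.
# Shapes/dtypes are assumptions following the usual convention, not part of this diff.
import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

E, H, D_ff, T = 8, 256, 512, 16  # experts, hidden size, intermediate size, tokens
hidden_states = torch.randn(T, H, device="cuda", dtype=torch.float16)
w1 = torch.randn(E, 2 * D_ff, H, device="cuda", dtype=torch.float16)  # gate + up projections
w2 = torch.randn(E, H, D_ff, device="cuda", dtype=torch.float16)      # down projection
gating_output = torch.randn(T, E, device="cuda", dtype=torch.float32)

out = fused_moe(
    hidden_states,
    w1,
    w2,
    gating_output,
    topk=2,
    renormalize=True,
    activation="gelu",  # new parameter; defaults to "silu"
)
```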
9 changes: 9 additions & 0 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -126,6 +126,7 @@ def apply(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
return self.forward(
x=x,
@@ -138,6 +139,7 @@ def apply(
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
activation=activation,
)

def forward_cuda(
@@ -152,6 +154,7 @@ def forward_cuda(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
@@ -169,6 +172,8 @@ def forward_cuda(
import ater
from ater.fused_moe import fused_experts_ck

assert activation == "silu", f"{activation=} is not supported."

return fused_experts_ck(
hidden_states=x,
w1=layer.w13_weight,
@@ -184,6 +189,7 @@ def forward_cuda(
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
)

def forward_cpu(
@@ -256,6 +262,7 @@ def __init__(
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
use_presharded_weights: bool = False,
):
super().__init__()
@@ -279,6 +286,7 @@ def __init__(
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.activation = activation

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
@@ -589,6 +597,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
activation=self.activation,
)

if self.reduce_results and self.tp_size > 1:
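The layer-level change follows one pattern throughout this file: the layer stores `activation` once at construction and threads it through the quant method's `apply()` into the kernel call. A toy sketch of that pattern (illustrative names only, not the real sglang classes):

```python
# Toy sketch of the threading pattern (illustrative names, not real sglang classes):
# the layer stores `activation` once, and forward() passes it through the quant
# method's apply() into the kernel call, so every backend sees the same value.
class ToyMethod:
    def apply(self, layer, x, activation: str = "silu"):
        print(f"dispatching fused experts with activation={activation!r}")
        return x

class ToyFusedMoE:
    def __init__(self, activation: str = "silu"):
        self.activation = activation      # stored once at construction
        self.quant_method = ToyMethod()

    def forward(self, x):
        # mirrors FusedMoE.forward passing activation=self.activation into apply()
        return self.quant_method.apply(self, x, activation=self.activation)

ToyFusedMoE(activation="gelu").forward([1.0, 2.0])
```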
5 changes: 4 additions & 1 deletion python/sglang/srt/layers/quantization/fp8.py
@@ -763,8 +763,8 @@ def apply(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts

@@ -785,6 +785,8 @@ def apply(
import ater
from ater.fused_moe import fused_experts_ck

assert activation == "silu", f"{activation=} is not supported."

return fused_experts_ck(
x,
layer.w13_weight,
@@ -815,6 +817,7 @@ def apply(
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
1 change: 1 addition & 0 deletions python/sglang/srt/models/grok.py
@@ -133,6 +133,7 @@ def __init__(
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
activation="gelu",
use_presharded_weights=use_presharded_weights,
)

2 changes: 0 additions & 2 deletions test/srt/test_fp8_kernel.py
@@ -2,8 +2,6 @@

import torch

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,