Skip to content

Commit

Permalink
Add activation parameters to fused_moe (#3170)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Jan 27, 2025
1 parent 741fccd commit 52c03f1
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 7 deletions.
3 changes: 3 additions & 0 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
):
super().__init__()

Expand All @@ -140,6 +141,7 @@ def __init__(
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.activation = activation

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
Expand All @@ -166,6 +168,7 @@ def __init__(

def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None
assert self.activation == "silu"

if self.grouped_gemm_runner is None:
self.grouped_gemm_runner = GroupedGemmRunner(
Expand Down
20 changes: 17 additions & 3 deletions python/sglang/srt/layers/moe/fused_moe_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch
from torch.nn import functional as F

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
from sglang.srt.layers.moe.topk import select_experts


Expand All @@ -23,6 +23,7 @@ def fused_moe_forward_native(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
Expand All @@ -41,7 +42,12 @@ def fused_moe_forward_native(
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = layer.w2_weight[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
if activation == "silu":
x1 = F.silu(x1)
elif activation == "gelu":
x1 = F.gelu(x1)
else:
raise ValueError(f"Unsupported activation: {activation=}")
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
Expand All @@ -58,6 +64,7 @@ def moe_forward_native(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:

topk_weights, topk_ids = select_experts(
Expand All @@ -84,6 +91,13 @@ def moe_forward_native(
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()

if activation == "silu":
act = SiluAndMul()
elif activation == "gelu":
act = GeluAndMul()
else:
raise ValueError(f"Unsupported activation: {activation=}")

outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
Expand All @@ -96,7 +110,7 @@ def moe_forward_native(
layer_w2_weight = layer.w2_weight[i]

gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
gate_up = SiluAndMul()(gate_up)
gate_up = act(gate_up)
expert_out = F.linear(gate_up, layer_w2_weight)
outputs.append(expert_out)
start_idx = end_idx
Expand Down
19 changes: 18 additions & 1 deletion python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,7 @@ def inplace_fused_experts(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
Expand All @@ -726,6 +727,7 @@ def inplace_fused_experts(
topk_weights,
topk_ids,
True,
activation,
use_fp8_w8a8,
use_int8_w8a16,
w1_scale,
Expand All @@ -742,6 +744,7 @@ def inplace_fused_experts_fake(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
Expand All @@ -767,6 +770,7 @@ def outplace_fused_experts(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
Expand All @@ -782,6 +786,7 @@ def outplace_fused_experts(
topk_weights,
topk_ids,
False,
activation,
use_fp8_w8a8,
use_int8_w8a16,
w1_scale,
Expand All @@ -798,6 +803,7 @@ def outplace_fused_experts_fake(
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
Expand All @@ -824,6 +830,7 @@ def fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
Expand All @@ -839,6 +846,7 @@ def fused_experts(
w2,
topk_weights,
topk_ids,
activation,
use_fp8_w8a8,
use_int8_w8a16,
w1_scale,
Expand All @@ -855,6 +863,7 @@ def fused_experts(
w2,
topk_weights,
topk_ids,
activation,
use_fp8_w8a8,
use_int8_w8a16,
w1_scale,
Expand All @@ -872,6 +881,7 @@ def fused_experts_impl(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -986,7 +996,12 @@ def fused_experts_impl(
block_shape=block_shape,
)

ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
if activation == "silu":
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
elif activation == "gelu":
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
else:
raise ValueError(f"Unsupported activation: {activation=}")

invoke_fused_moe_kernel(
intermediate_cache2,
Expand Down Expand Up @@ -1042,6 +1057,7 @@ def fused_moe(
topk: int,
renormalize: bool,
inplace: bool = False,
activation: str = "silu",
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
Expand Down Expand Up @@ -1111,6 +1127,7 @@ def fused_moe(
topk_weights,
topk_ids,
inplace=inplace,
activation=activation,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
Expand Down
9 changes: 9 additions & 0 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def apply(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
return self.forward(
x=x,
Expand All @@ -138,6 +139,7 @@ def apply(
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
activation=activation,
)

def forward_cuda(
Expand All @@ -152,6 +154,7 @@ def forward_cuda(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
Expand All @@ -169,6 +172,8 @@ def forward_cuda(
import ater
from ater.fused_moe import fused_experts_ck

assert activation == "silu", f"{activation=} is not supported."

return fused_experts_ck(
hidden_states=x,
w1=layer.w13_weight,
Expand All @@ -184,6 +189,7 @@ def forward_cuda(
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
)

def forward_cpu(
Expand Down Expand Up @@ -256,6 +262,7 @@ def __init__(
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
use_presharded_weights: bool = False,
):
super().__init__()
Expand All @@ -279,6 +286,7 @@ def __init__(
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.activation = activation

if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
Expand Down Expand Up @@ -589,6 +597,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
activation=self.activation,
)

if self.reduce_results and self.tp_size > 1:
Expand Down
5 changes: 4 additions & 1 deletion python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,8 +763,8 @@ def apply(
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts

Expand All @@ -785,6 +785,8 @@ def apply(
import ater
from ater.fused_moe import fused_experts_ck

assert activation == "silu", f"{activation=} is not supported."

return fused_experts_ck(
x,
layer.w13_weight,
Expand Down Expand Up @@ -815,6 +817,7 @@ def apply(
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
Expand Down
1 change: 1 addition & 0 deletions python/sglang/srt/models/grok.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
activation="gelu",
use_presharded_weights=use_presharded_weights,
)

Expand Down
2 changes: 0 additions & 2 deletions test/srt/test_fp8_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import torch

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
Expand Down

0 comments on commit 52c03f1

Please sign in to comment.