From 72f8ed5134e5e637214f9bcd3bcd77e04787f1ba Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 26 Jan 2025 23:57:01 -0800 Subject: [PATCH 1/3] Fix fused moe --- .../sglang/srt/layers/moe/fused_moe_native.py | 20 ++++++++++++++++--- .../layers/moe/fused_moe_triton/fused_moe.py | 19 +++++++++++++++++- .../srt/layers/moe/fused_moe_triton/layer.py | 8 ++++++++ python/sglang/srt/layers/quantization/fp8.py | 5 ++++- python/sglang/srt/models/grok.py | 1 + test/srt/test_fp8_kernel.py | 2 -- 6 files changed, 48 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_native.py b/python/sglang/srt/layers/moe/fused_moe_native.py index 0703e840ca6..042c0a52c56 100644 --- a/python/sglang/srt/layers/moe/fused_moe_native.py +++ b/python/sglang/srt/layers/moe/fused_moe_native.py @@ -8,7 +8,7 @@ import torch from torch.nn import functional as F -from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.activation import GeluAndMul, SiluAndMul from sglang.srt.layers.moe.topk import select_experts @@ -23,6 +23,7 @@ def fused_moe_forward_native( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: topk_weights, topk_ids = select_experts( hidden_states=x, @@ -41,7 +42,12 @@ def fused_moe_forward_native( w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) w2_weights = layer.w2_weight[topk_ids] x1 = torch.einsum("ti,taoi -> tao", x, w1_weights) - x1 = F.silu(x1) + if activation == "silu": + x1 = F.silu(x1) + elif activation == "gelu": + x1 = F.gelu(x1) + else: + raise ValueError(f"Unsupported activation: {activation=}") x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) @@ -58,6 +64,7 @@ def moe_forward_native( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: topk_weights, topk_ids = select_experts( @@ -84,6 +91,13 @@ def moe_forward_native( sorted_tokens = x[idxs // topk_ids.shape[1]] tokens_per_expert = tokens_per_expert.cpu().numpy() + if activation == "silu": + act = SiluAndMul() + elif activation == "gelu": + act = GeluAndMul() + else: + raise ValueError(f"Unsupported activation: {activation=}") + outputs = [] start_idx = 0 for i, num_tokens in enumerate(tokens_per_expert): @@ -96,7 +110,7 @@ def moe_forward_native( layer_w2_weight = layer.w2_weight[i] gate_up = F.linear(tokens_for_this_expert, layer_w13_weight) - gate_up = SiluAndMul()(gate_up) + gate_up = act(gate_up) expert_out = F.linear(gate_up, layer_w2_weight) outputs.append(expert_out) start_idx = end_idx diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index c0d55808558..32c8fcbb625 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -711,6 +711,7 @@ def inplace_fused_experts( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -726,6 +727,7 @@ def inplace_fused_experts( topk_weights, topk_ids, True, + activation, use_fp8_w8a8, use_int8_w8a16, w1_scale, @@ -742,6 +744,7 @@ def inplace_fused_experts_fake( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -767,6 +770,7 @@ def outplace_fused_experts( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -782,6 +786,7 @@ def outplace_fused_experts( topk_weights, topk_ids, False, + activation, use_fp8_w8a8, use_int8_w8a16, w1_scale, @@ -798,6 +803,7 @@ def outplace_fused_experts_fake( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -824,6 +830,7 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -839,6 +846,7 @@ def fused_experts( w2, topk_weights, topk_ids, + activation, use_fp8_w8a8, use_int8_w8a16, w1_scale, @@ -855,6 +863,7 @@ def fused_experts( w2, topk_weights, topk_ids, + activation, use_fp8_w8a8, use_int8_w8a16, w1_scale, @@ -872,6 +881,7 @@ def fused_experts_impl( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -986,7 +996,12 @@ def fused_experts_impl( block_shape=block_shape, ) - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + if activation == "silu": + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + elif activation == "gelu": + ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + else: + raise ValueError(f"Unsupported activation: {activation=}") invoke_fused_moe_kernel( intermediate_cache2, @@ -1042,6 +1057,7 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, + activation: str = "silu", use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, @@ -1111,6 +1127,7 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, + activation=activation, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, w1_scale=w1_scale, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 75d4c5ead65..06374563bd4 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -126,6 +126,7 @@ def apply( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: return self.forward( x=x, @@ -138,6 +139,7 @@ def apply( num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, correction_bias=correction_bias, + activation=activation, ) def forward_cuda( @@ -152,6 +154,7 @@ def forward_cuda( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: topk_weights, topk_ids = select_experts( hidden_states=x, @@ -169,6 +172,8 @@ def forward_cuda( import ater from ater.fused_moe import fused_experts_ck + assert activation == "silu", f"{activation=} is not supported." + return fused_experts_ck( hidden_states=x, w1=layer.w13_weight, @@ -184,6 +189,7 @@ def forward_cuda( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, ) def forward_cpu( @@ -256,6 +262,7 @@ def __init__( prefix: str = "", custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", use_presharded_weights: bool = False, ): super().__init__() @@ -589,6 +596,7 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, correction_bias=self.correction_bias, + activation=self.activation, ) if self.reduce_results and self.tp_size > 1: diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index bd59352a796..b0b5b8952a1 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -763,8 +763,8 @@ def apply( num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: - from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.topk import select_experts @@ -785,6 +785,8 @@ def apply( import ater from ater.fused_moe import fused_experts_ck + assert activation == "silu", f"{activation=} is not supported." + return fused_experts_ck( x, layer.w13_weight, @@ -815,6 +817,7 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, + activation=activation, use_fp8_w8a8=True, w1_scale=( layer.w13_weight_scale_inv diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index c13d3e25368..0471e37d982 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -133,6 +133,7 @@ def __init__( renormalize=False, quant_config=quant_config, tp_size=tp_size, + activation="gelu", use_presharded_weights=use_presharded_weights, ) diff --git a/test/srt/test_fp8_kernel.py b/test/srt/test_fp8_kernel.py index bd2d5d16815..fe92bfd0769 100644 --- a/test/srt/test_fp8_kernel.py +++ b/test/srt/test_fp8_kernel.py @@ -2,8 +2,6 @@ import torch -from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul, From 559fb792ebd7aba108419d380fb02755a14d679b Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 27 Jan 2025 00:00:07 -0800 Subject: [PATCH 2/3] Fix --- python/sglang/srt/layers/moe/fused_moe_triton/layer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 06374563bd4..b71a878a0ba 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -286,6 +286,7 @@ def __init__( self.topk_group = topk_group self.custom_routing_function = custom_routing_function self.correction_bias = correction_bias + self.activation = activation if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( From 7ef80eaa99013a73d02edfc9225639093226b9af Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 27 Jan 2025 00:08:07 -0800 Subject: [PATCH 3/3] Fix activation --- python/sglang/srt/layers/moe/ep_moe/layer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 8f5a71dff8c..20e07d3a597 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -114,6 +114,7 @@ def __init__( tp_size: Optional[int] = None, prefix: str = "", correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ): super().__init__() @@ -140,6 +141,7 @@ def __init__( self.num_expert_group = num_expert_group self.topk_group = topk_group self.correction_bias = correction_bias + self.activation = activation if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod() @@ -166,6 +168,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None + assert self.activation == "silu" if self.grouped_gemm_runner is None: self.grouped_gemm_runner = GroupedGemmRunner(