From 8db776f049732141d1acd6f0c7c24d2297974f31 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sat, 1 Feb 2025 19:31:47 +0800
Subject: [PATCH] support QuickGELU (#3250)

---
 python/sglang/srt/layers/activation.py | 9 +++++++++
 python/sglang/srt/models/qwen2_vl.py   | 2 +-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 08ea91b9c1f..82c39c2acbc 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -72,6 +72,15 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         return out
 
 
+class QuickGELU(CustomOp):
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.sigmoid(1.702 * x)
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
+        return self.forward_native(x)
+
+
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
 
diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py
index 365891544e0..adc50508190 100644
--- a/python/sglang/srt/models/qwen2_vl.py
+++ b/python/sglang/srt/models/qwen2_vl.py
@@ -31,10 +31,10 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from vllm.model_executor.layers.activation import QuickGELU
 
 from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
 from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
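
Note (not part of the patch): a minimal standalone sketch of what the new
forward_native path computes, comparing the 1.702-sigmoid approximation
against PyTorch's exact erf-based GELU. The quick_gelu helper below is
illustrative only; CustomOp dispatch and the future sgl-kernel CUDA path
are elided.

    import torch
    import torch.nn.functional as F

    def quick_gelu(x: torch.Tensor) -> torch.Tensor:
        # Same formula as QuickGELU.forward_native in this patch: a
        # sigmoid-based GELU approximation used by CLIP-style vision towers.
        return x * torch.sigmoid(1.702 * x)

    x = torch.linspace(-4.0, 4.0, steps=9)
    exact = F.gelu(x)        # exact (erf-based) GELU
    approx = quick_gelu(x)
    # The approximation tracks exact GELU closely but is not identical,
    # which is why Qwen2-VL must use QuickGELU rather than plain GELU.
    print((exact - approx).abs().max())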