From 5317902670fcedc59861b41e1fa36a49866495db Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Sat, 1 Feb 2025 16:07:54 +0800
Subject: [PATCH 1/7] Add test for fp8 torch compile (#3246)

---
 test/srt/test_mla.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py
index 34bc4b44645..6305732509b 100644
--- a/test/srt/test_mla.py
+++ b/test/srt/test_mla.py
@@ -62,7 +62,12 @@ def setUpClass(cls):
             cls.model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
-            other_args=["--trust-remote-code"],
+            other_args=[
+                "--trust-remote-code",
+                "--enable-torch-compile",
+                "--cuda-graph-max-bs",
+                "2",
+            ],
         )
 
     @classmethod
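The hunk above runs the MLA fp8 test with torch.compile enabled and CUDA graph capture capped at batch size 2. For reference, a minimal sketch of launching a server with the same flags outside the test class is below; it assumes popen_launch_server and DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH come from sglang.test.test_utils as in test_mla.py, and the model name and port are placeholders rather than values taken from the test.

    # Sketch only: same server flags as the patched test, outside the unittest harness.
    from sglang.test.test_utils import (  # assumed import path, as used by the SRT tests
        DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        popen_launch_server,
    )

    base_url = "http://127.0.0.1:30000"  # placeholder port
    process = popen_launch_server(
        "deepseek-ai/DeepSeek-V2-Lite",  # placeholder MLA-style model
        base_url,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=[
            "--trust-remote-code",
            "--enable-torch-compile",   # run the model under torch.compile
            "--cuda-graph-max-bs",      # limit CUDA graph capture to small batch sizes
            "2",
        ],
    )
    try:
        ...  # issue requests against base_url here
    finally:
        process.terminate()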
From 17dbf976c58de83ce1d410a177954d60278b3505 Mon Sep 17 00:00:00 2001
From: HAI
Date: Sat, 1 Feb 2025 01:27:43 -0800
Subject: [PATCH 2/7] update ENV to ROCm dockers (#3248)

---
 docker/Dockerfile.rocm | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm
index af9f9e24df7..e1a242c87fc 100644
--- a/docker/Dockerfile.rocm
+++ b/docker/Dockerfile.rocm
@@ -58,6 +58,7 @@ RUN git clone ${ATER_REPO} \
 
 # Performance environment variable.
 ENV HIP_FORCE_DEV_KERNARG=1
+ENV HSA_NO_SCRATCH_RECLAIM=1
 ENV SGLANG_SET_CPU_AFFINITY=1
 ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
 ENV NCCL_MIN_NCHANNELS=112

From 4eb4b401cc552cab162165e22e1428086eb0f874 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sat, 1 Feb 2025 18:56:44 +0800
Subject: [PATCH 3/7] update and simplify CustomOp (#3249)

---
 python/sglang/srt/custom_op.py               | 40 +++++++++++++++++++
 python/sglang/srt/layers/activation.py       |  6 +--
 python/sglang/srt/layers/custom_op_util.py   | 25 ------------
 python/sglang/srt/layers/layernorm.py        |  6 +--
 python/sglang/srt/layers/moe/ep_moe/layer.py |  4 +-
 .../srt/layers/moe/fused_moe_triton/layer.py |  4 +-
 python/sglang/srt/layers/rotary_embedding.py |  4 +-
 .../srt/model_executor/cuda_graph_runner.py  |  2 +-
 8 files changed, 46 insertions(+), 45 deletions(-)
 create mode 100644 python/sglang/srt/custom_op.py
 delete mode 100644 python/sglang/srt/layers/custom_op_util.py

diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py
new file mode 100644
index 00000000000..a702e8f822c
--- /dev/null
+++ b/python/sglang/srt/custom_op.py
@@ -0,0 +1,40 @@
+import torch
+from torch import nn
+
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_rocm = torch.cuda.is_available() and torch.version.hip
+
+
+class CustomOp(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._forward_method = self.dispatch_forward()
+
+    def forward(self, *args, **kwargs):
+        return self._forward_method(*args, **kwargs)
+
+    def forward_native(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_cuda(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_hip(self, *args, **kwargs):
+        raise NotImplementedError
+
+    def forward_xpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward_hpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def forward_cpu(self, *args, **kwargs):
+        return self.forward_native(*args, **kwargs)
+
+    def dispatch_forward(self):
+        if _is_cuda:
+            return self.forward_cuda
+        elif _is_rocm:
+            return self.forward_hip
+        else:
+            return self.forward_native

diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index d69d854ab2e..08ea91b9c1f 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -25,21 +25,18 @@
 if is_cuda_available():
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
-from vllm.model_executor.custom_op import CustomOp
-
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs
 
 logger = logging.getLogger(__name__)
 
 
-@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -53,7 +50,6 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         return out
 
 
-@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()

diff --git a/python/sglang/srt/layers/custom_op_util.py b/python/sglang/srt/layers/custom_op_util.py
deleted file mode 100644
index 92e186cd207..00000000000
--- a/python/sglang/srt/layers/custom_op_util.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# Copyright 2023-2024 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from vllm.model_executor.custom_op import CustomOp
-
-
-def register_custom_op(op_name):
-    def decorator(cls):
-        if hasattr(CustomOp, "register"):
-            return CustomOp.register(op_name)(cls)
-        else:
-            return cls
-
-    return decorator

diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 207ba8d1b7a..e3b23a2a926 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -29,14 +29,11 @@
         rmsnorm,
     )
 
-from vllm.model_executor.custom_op import CustomOp
-
-from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.custom_op import CustomOp
 
 logger = logging.getLogger(__name__)
 
 
-@register_custom_op("sglang_rmsnorm")
 class RMSNorm(CustomOp):
     def __init__(
         self,
@@ -79,7 +76,6 @@ def forward_native(
         return x, residual
 
 
-@register_custom_op("sglang_gemma_rmsnorm")
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,

diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index bc927621a84..4d6040646b3 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -4,13 +4,12 @@
 import torch
 from torch.nn import Module
 from vllm import _custom_ops as ops
-from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
     post_reorder_triton_kernel,
@@ -407,7 +406,6 @@ def _load_fp8_scale(
             param_data[expert_id] = loaded_weight
 
 
-@register_custom_op("sglang_unquantized_ep_moe")
 class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
     def create_weights(
         self,

diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
index b71a878a0ba..dc7152da934 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -5,14 +5,13 @@
 from typing import Callable, List, Optional, Tuple
 
 import torch
-from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
@@ -67,7 +66,6 @@ def apply(
         raise NotImplementedError
 
 
-@register_custom_op("sglang_unquantized_fused_moe")
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""
 

diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 7093bb90d81..ef8a96c9854 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -7,9 +7,8 @@
 import torch
 import torch.nn as nn
 from vllm import _custom_ops as ops
-from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import is_cuda_available
 
 _is_cuda_available = is_cuda_available()
@@ -59,7 +58,6 @@ def _apply_rotary_emb(
     return torch.stack((o1, o2), dim=-1).flatten(-2)
 
 
-@register_custom_op("sglang_rotary_embedding")
 class RotaryEmbedding(CustomOp):
     """Original rotary positional embedding."""
 

diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 93b4d0ea57a..69615b8ff31 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -21,8 +21,8 @@
 
 import torch
 import tqdm
-from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
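The new sglang.srt.custom_op.CustomOp base class replaces both the vLLM CustomOp import and the register_custom_op decorator: a subclass implements the per-backend forward_* methods, and dispatch_forward() binds the right one once in __init__ (forward_cuda on CUDA builds, forward_hip on ROCm builds, forward_native otherwise). A minimal sketch of a subclass is below; ExampleScale is a hypothetical op used only to illustrate the dispatch, not something added by this patch.

    import torch

    from sglang.srt.custom_op import CustomOp  # module introduced by this patch


    class ExampleScale(CustomOp):
        """Hypothetical op that doubles its input; only the dispatch is the point."""

        def forward_native(self, x: torch.Tensor) -> torch.Tensor:
            # Pure-PyTorch path, used on CPU/XPU/HPU and as the reference.
            return x * 2.0

        def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
            # A real op would call a fused sgl-kernel here.
            return x * 2.0


    op = ExampleScale()
    y = op(torch.ones(4))  # forward() calls the method bound by dispatch_forward()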
From 8db776f049732141d1acd6f0c7c24d2297974f31 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sat, 1 Feb 2025 19:31:47 +0800
Subject: [PATCH 4/7] support QuickGELU (#3250)

---
 python/sglang/srt/layers/activation.py | 9 +++++++++
 python/sglang/srt/models/qwen2_vl.py   | 2 +-
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 08ea91b9c1f..82c39c2acbc 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -72,6 +72,15 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         return out
 
 
+class QuickGELU(CustomOp):
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.sigmoid(1.702 * x)
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
+        return self.forward_native(x)
+
+
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
 
diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py
index 365891544e0..adc50508190 100644
--- a/python/sglang/srt/models/qwen2_vl.py
+++ b/python/sglang/srt/models/qwen2_vl.py
@@ -31,10 +31,10 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from vllm.model_executor.layers.activation import QuickGELU
 
 from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
 from sglang.srt.hf_transformers_utils import get_processor
+from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
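QuickGELU is the sigmoid approximation of GELU, x * sigmoid(1.702 * x); this patch moves Qwen2-VL off the vLLM implementation and onto the sglang copy above. A small sanity-check sketch follows (the tensor shape is an arbitrary choice, not taken from the patch):

    import torch

    from sglang.srt.layers.activation import QuickGELU  # import path added by this patch

    act = QuickGELU()
    x = torch.randn(2, 8)

    # forward_native is the reference formula; forward_cuda currently falls back to it.
    assert torch.allclose(act.forward_native(x), x * torch.sigmoid(1.702 * x))

    y = act(x)  # goes through CustomOp dispatch (forward_cuda on CUDA builds)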
From ad6740977b0358caeac2606936aa18e0513b2a11 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sat, 1 Feb 2025 19:47:44 +0800
Subject: [PATCH 5/7] add contact us in README (#3251)

---
 README.md | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/README.md b/README.md
index b27271a1810..b0f28e985a5 100644
--- a/README.md
+++ b/README.md
@@ -60,5 +60,9 @@ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s
 ## Adoption and Sponsorship
 The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS.org, Meituan, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI.
 
+## Contact Us
+
+For enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at contact@sglang.ai or business@sglang.ai.
+
 ## Acknowledgment and Citation
 We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). Please cite the paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful.

From f2b3a3188ed5504c02b4f18fbae5c1ae49babe40 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sat, 1 Feb 2025 21:19:15 +0800
Subject: [PATCH 6/7] Update README

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index b0f28e985a5..4b17633d817 100644
--- a/README.md
+++ b/README.md
@@ -62,7 +62,7 @@ The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch,
 
 ## Contact Us
 
-For enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at contact@sglang.ai or business@sglang.ai.
+For enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at contact@sglang.ai.
 
 ## Acknowledgment and Citation
 We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). Please cite the paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful.
From 959dca4fc7d720b8885e74761f7b098bed2bdeb7 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sat, 1 Feb 2025 22:23:09 +0800
Subject: [PATCH 7/7] use srt VocabParallelEmbedding (#3252)

---
 python/sglang/srt/lora/lora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py
index c8cbe36602b..871c1a2291f 100644
--- a/python/sglang/srt/lora/lora.py
+++ b/python/sglang/srt/lora/lora.py
@@ -23,7 +23,6 @@
 
 import torch
 from torch import nn
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -31,6 +30,7 @@
     QKVParallelLinear,
     RowParallelLinear,
 )
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_loader.loader import DefaultModelLoader