Commit

Merge branch 'main' into fix_min_p
zhaochenyang20 authored Feb 1, 2025
2 parents 3f61c4a + 959dca4 commit 18e3d04
Showing 13 changed files with 68 additions and 48 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -60,5 +60,9 @@ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s
## Adoption and Sponsorship
The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS.org, Meituan, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI.

## Contact Us

For enterprises interested in adopting or deploying SGLang at scale, including technical consulting, sponsorship opportunities, or partnership inquiries, please contact us at [email protected].

## Acknowledgment and Citation
We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). Please cite the paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful.
1 change: 1 addition & 0 deletions docker/Dockerfile.rocm
@@ -58,6 +58,7 @@ RUN git clone ${ATER_REPO} \
# Performance environment variable.

ENV HIP_FORCE_DEV_KERNARG=1
ENV HSA_NO_SCRATCH_RECLAIM=1
ENV SGLANG_SET_CPU_AFFINITY=1
ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1
ENV NCCL_MIN_NCHANNELS=112
40 changes: 40 additions & 0 deletions python/sglang/srt/custom_op.py
@@ -0,0 +1,40 @@
import torch
from torch import nn

_is_cuda = torch.cuda.is_available() and torch.version.cuda
_is_rocm = torch.cuda.is_available() and torch.version.hip


class CustomOp(nn.Module):
    def __init__(self):
        super().__init__()
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        raise NotImplementedError

    def forward_xpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_hpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self):
        if _is_cuda:
            return self.forward_cuda
        elif _is_rocm:
            return self.forward_hip
        else:
            return self.forward_native
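
For reference, a minimal usage sketch of the dispatcher above (hypothetical subclass, not part of this commit): a CustomOp subclass only needs to implement forward_native plus whatever accelerated backends it has, and dispatch_forward() resolves the backend once at construction time, so every forward() call goes straight to the right method.

```python
# Hypothetical sketch of how a CustomOp subclass plugs into the dispatcher above.
import torch

from sglang.srt.custom_op import CustomOp


class ToySiLU(CustomOp):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Pure-PyTorch fallback, used on CPU/XPU/HPU or when no kernel is wired up.
        return x * torch.sigmoid(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A real op would call into sgl-kernel here; this sketch reuses the native path.
        return self.forward_native(x)


op = ToySiLU()             # dispatch_forward() picks forward_cuda / forward_hip / forward_native once
y = op(torch.randn(4, 8))  # forward() routes through the resolved backend method
```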
15 changes: 10 additions & 5 deletions python/sglang/srt/layers/activation.py
@@ -25,21 +25,18 @@
if is_cuda_available():
    from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

from vllm.model_executor.custom_op import CustomOp

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import set_weight_attrs

logger = logging.getLogger(__name__)


@register_custom_op("sglang_silu_and_mul")
class SiluAndMul(CustomOp):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
@@ -53,7 +50,6 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return out


@register_custom_op("sglang_gelu_and_mul")
class GeluAndMul(CustomOp):
    def __init__(self, approximate="tanh"):
        super().__init__()
@@ -76,6 +72,15 @@ def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return out


class QuickGELU(CustomOp):
    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # TODO(zhyncs): Implement the CUDA kernel for QuickGELU in sgl-kernel
        return self.forward_native(x)


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.
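
The QuickGELU added above is the standard sigmoid approximation of GELU, x * torch.sigmoid(1.702 * x). An illustrative sanity check (not part of this diff) of how close the native path stays to exact GELU:

```python
# Illustrative only: compare QuickGELU's approximation with exact GELU.
import torch
import torch.nn.functional as F

x = torch.linspace(-6.0, 6.0, steps=1001)
quick = x * torch.sigmoid(1.702 * x)   # same formula as QuickGELU.forward_native
exact = F.gelu(x)                      # exact (erf-based) GELU
print(f"max |quick - exact| = {(quick - exact).abs().max().item():.4f}")  # small, on the order of 1e-2
```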
25 changes: 0 additions & 25 deletions python/sglang/srt/layers/custom_op_util.py

This file was deleted.

6 changes: 1 addition & 5 deletions python/sglang/srt/layers/layernorm.py
@@ -29,14 +29,11 @@
    rmsnorm,
)

from vllm.model_executor.custom_op import CustomOp

from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.custom_op import CustomOp

logger = logging.getLogger(__name__)


@register_custom_op("sglang_rmsnorm")
class RMSNorm(CustomOp):
    def __init__(
        self,
@@ -79,7 +76,6 @@ def forward_native(
        return x, residual


@register_custom_op("sglang_gemma_rmsnorm")
class GemmaRMSNorm(CustomOp):
    def __init__(
        self,
4 changes: 1 addition & 3 deletions python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -4,13 +4,12 @@
import torch
from torch.nn import Module
from vllm import _custom_ops as ops
from vllm.model_executor.custom_op import CustomOp

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.moe.ep_moe.kernels import (
    grouped_gemm_triton,
    post_reorder_triton_kernel,
@@ -407,7 +406,6 @@ def _load_fp8_scale(
        param_data[expert_id] = loaded_weight


@register_custom_op("sglang_unquantized_ep_moe")
class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
    def create_weights(
        self,
4 changes: 1 addition & 3 deletions python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -5,14 +5,13 @@
from typing import Callable, List, Optional, Tuple

import torch
from vllm.model_executor.custom_op import CustomOp

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import (
@@ -67,7 +66,6 @@ def apply(
        raise NotImplementedError


@register_custom_op("sglang_unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
    """MoE method without quantization."""

4 changes: 1 addition & 3 deletions python/sglang/srt/layers/rotary_embedding.py
@@ -7,9 +7,8 @@
import torch
import torch.nn as nn
from vllm import _custom_ops as ops
from vllm.model_executor.custom_op import CustomOp

from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda_available

_is_cuda_available = is_cuda_available()
@@ -59,7 +58,6 @@ def _apply_rotary_emb(
    return torch.stack((o1, o2), dim=-1).flatten(-2)


@register_custom_op("sglang_rotary_embedding")
class RotaryEmbedding(CustomOp):
    """Original rotary positional embedding."""

2 changes: 1 addition & 1 deletion python/sglang/srt/lora/lora.py
@@ -23,14 +23,14 @@

import torch
from torch import nn
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding

from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    RowParallelLinear,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
from sglang.srt.model_loader.loader import DefaultModelLoader


2 changes: 1 addition & 1 deletion python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -21,8 +21,8 @@

import torch
import tqdm
from vllm.model_executor.custom_op import CustomOp

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
2 changes: 1 addition & 1 deletion python/sglang/srt/models/qwen2_vl.py
@@ -31,10 +31,10 @@
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vllm.model_executor.layers.activation import QuickGELU

from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.activation import QuickGELU
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
7 changes: 6 additions & 1 deletion test/srt/test_mla.py
@@ -62,7 +62,12 @@ def setUpClass(cls):
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--trust-remote-code"],
            other_args=[
                "--trust-remote-code",
                "--enable-torch-compile",
                "--cuda-graph-max-bs",
                "2",
            ],
        )

    @classmethod
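
For context, the two new arguments are regular server flags. Roughly, the configuration this test now requests looks like the sketch below; the model path and port are placeholders, and the test itself goes through its launch helper with cls.model and cls.base_url rather than a subprocess call like this.

```python
# Sketch only: approximate standalone equivalent of the server the test launches.
import subprocess

cmd = [
    "python", "-m", "sglang.launch_server",
    "--model-path", "deepseek-ai/DeepSeek-V2-Lite",  # placeholder; the test uses cls.model
    "--port", "30000",                               # placeholder; the test uses cls.base_url
    "--trust-remote-code",
    "--enable-torch-compile",      # also compile the model forward with torch.compile
    "--cuda-graph-max-bs", "2",    # cap CUDA-graph capture at batch size 2
]
server = subprocess.Popen(cmd)  # the test then waits for the HTTP endpoint before sending requests
```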
