diff --git a/python/pyproject.toml b/python/pyproject.toml index 97e0771cd90..d4063cf016b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -27,7 +27,7 @@ runtime_common = [ ] srt = [ "sglang[runtime_common]", "cuda-python", - "sgl-kernel>=0.0.2.post18", "torch", "vllm==0.6.4.post1", + "sgl-kernel>=0.0.3", "torch", "vllm==0.6.4.post1", "flashinfer==0.1.6" ] diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index ebb0652c5d2..d69d854ab2e 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -20,10 +20,10 @@ import torch.nn as nn import torch.nn.functional as F -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda_available -if is_flashinfer_available(): - from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul +if is_cuda_available(): + from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul from vllm.model_executor.custom_op import CustomOp @@ -149,8 +149,8 @@ def get_act_fn( return act_fn -if not is_flashinfer_available(): +if not is_cuda_available(): logger.info( - "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries." + "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." ) from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index bd95b9bccce..207ba8d1b7a 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -19,10 +19,10 @@ import torch import torch.nn as nn -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda_available -if is_flashinfer_available(): - from flashinfer.norm import ( +if is_cuda_available(): + from sgl_kernel import ( fused_add_rmsnorm, gemma_fused_add_rmsnorm, gemma_rmsnorm, @@ -121,8 +121,8 @@ def forward_cuda( return out -if not is_flashinfer_available(): +if not is_cuda_available(): logger.info( - "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries." + "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." ) from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 3173d533d16..b24bfc8dacf 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -10,14 +10,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo -from sglang.srt.utils import ( - crash_on_warnings, - get_bool_env_var, - is_flashinfer_available, -) - -if is_flashinfer_available(): - from flashinfer.sampling import ( +from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available + +if is_cuda_available(): + from sgl_kernel import ( min_p_sampling_from_probs, top_k_renorm_prob, top_k_top_p_sampling_from_probs, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 17d7fcf8924..4384410476c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -56,12 +56,12 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import is_flashinfer_available, is_hip +from sglang.srt.utils import is_cuda_available, is_hip is_hip_ = is_hip() -if is_flashinfer_available(): - from flashinfer import bmm_fp8 +if is_cuda_available(): + from sgl_kernel import bmm_fp8 class DeepseekV2MLP(nn.Module): diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index 118be8ff6c8..31ea7cd9f25 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -40,10 +40,10 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda_available -if is_flashinfer_available(): - from flashinfer import bmm_fp8 +if is_cuda_available(): + from sgl_kernel import bmm_fp8 class MiniCPM3MLP(nn.Module):