From 1dde914a1e6b8a590dbb02a7c69c53b438b7c4c0 Mon Sep 17 00:00:00 2001 From: Stephen Baione Date: Thu, 24 Oct 2024 09:39:51 -0500 Subject: [PATCH 1/5] Add support for `Shortfin` backend --- .../quick_start/shortfin_example_chat.py | 66 +++++++++++++ python/sglang/__init__.py | 3 +- python/sglang/lang/backend/shortfin.py | 95 +++++++++++++++++++ python/sglang/lang/ir.py | 17 ++++ python/sglang/utils.py | 10 +- 5 files changed, 187 insertions(+), 4 deletions(-) create mode 100644 examples/frontend_language/quick_start/shortfin_example_chat.py create mode 100644 python/sglang/lang/backend/shortfin.py diff --git a/examples/frontend_language/quick_start/shortfin_example_chat.py b/examples/frontend_language/quick_start/shortfin_example_chat.py new file mode 100644 index 00000000000..3e943477dc0 --- /dev/null +++ b/examples/frontend_language/quick_start/shortfin_example_chat.py @@ -0,0 +1,66 @@ +""" +Usage: +# Prior to running this script, you need to have a Shortfin server running. +# Build: +# https://github.com/nod-ai/SHARK-Platform/blob/main/shortfin/README.md +# Run: +# https://github.com/nod-ai/SHARK-Platform/blob/main/shortfin/python/shortfin_apps/llm/README.md + +python3 shortfin_example_chat.py --base_url +""" + +import argparse +import os + +import sglang as sgl + + +@sgl.function +def multi_turn_question(s, question_1, question_2): + s += sgl.system("You are a helpful assistant.") + s += sgl.user(question_1) + s += sgl.assistant(sgl.gen("answer_1", max_tokens=256)) + s += sgl.user(question_2) + s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) + + +def single(): + state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", + ) + + for m in state.messages(): + print(m["role"], ":", m["content"]) + + print("\n-- answer_1 --\n", state["answer_1"]) + + +def stream(): + state = multi_turn_question.run( + question_1="What is the capital of the United States?", + question_2="List two local attractions.", + stream=True, + ) + + for out in state.text_iter(): + print(out, end="", flush=True) + print() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base_url", default="http://localhost:8000") + args = parser.parse_args() + base_url = args.base_url + + backend = sgl.Shortfin(base_url=base_url) + sgl.set_default_backend(backend) + + # Run a single request + print("\n========== single ==========\n") + single() + + # Stream output + print("\n========== stream ==========\n") + stream() diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index 3c4457c983a..c4d5108c57b 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -74,5 +74,6 @@ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM") OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI") VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI") +Shortfin = LazyImport("sglang.lang.backend.shortfin", "Shortfin") -__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"] +__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "Shortfin", "RuntimeEndpoint"] diff --git a/python/sglang/lang/backend/shortfin.py b/python/sglang/lang/backend/shortfin.py new file mode 100644 index 00000000000..d67a49e0d12 --- /dev/null +++ b/python/sglang/lang/backend/shortfin.py @@ -0,0 +1,95 @@ +import json +from typing import Optional + +from sglang.lang.backend.base_backend import BaseBackend +from sglang.lang.chat_template import get_chat_template_by_model_path +from 
sglang.lang.interpreter import StreamExecutor +from sglang.lang.ir import SglSamplingParams +from sglang.utils import http_request + + +class Shortfin(BaseBackend): + def __init__( + self, + chat_template=None, + base_url: Optional[str] = None, + timeout: Optional[float] = None, + ): + super().__init__() + + if base_url is None: + raise ValueError("`base_url` is required for Shortfin backend") + + self.chat_template = chat_template or get_chat_template_by_model_path("default") + + self.client_params = {"base_url": base_url, "timeout": timeout} + + def _make_generate_request(self, shortfin_kwargs, stream=False): + resp = http_request( + f"{self.client_params['base_url']}/generate", + json=shortfin_kwargs, + timeout=self.client_params["timeout"], + stream=stream, + ) + self._assert_success(resp) + return resp + + def _assert_success(self, res): + if res.status_code != 200: + raise RuntimeError(res.json()) + + def _clean_response_message(self, text): + return text.replace(b"data: ", b"").strip(b"\n") + + def get_chat_template(self): + return self.chat_template + + def generate( + self, + s: StreamExecutor, + sampling_params: SglSamplingParams, + ): + shortfin_kwargs = sampling_params.to_shortfin_kwargs() + + messages = s.text_ + shortfin_kwargs["text"] = messages + + resp = http_request( + f"{self.client_params['base_url']}/generate", + json=shortfin_kwargs, + timeout=self.client_params["timeout"], + ) + self._assert_success(resp) + + response_message = resp.resp.read() + response_message = self._clean_response_message(response_message) + return response_message.decode("utf-8"), {} + + def generate_stream( + self, + s: StreamExecutor, + sampling_params: SglSamplingParams, + ): + shortfin_kwargs = sampling_params.to_shortfin_kwargs() + shortfin_kwargs["stream"] = True + + messages = s.text_ + shortfin_kwargs["text"] = messages + + resp = http_request( + f"{self.client_params['base_url']}/generate", + json=shortfin_kwargs, + stream=True, + timeout=self.client_params["timeout"], + ) + self._assert_success(resp) + + prefix = b"" + for chunk in resp: + if chunk == b"data: [DONE]\n\n": + break + text = chunk[len(prefix) :] + prefix += text.strip(b"\n") + text = self._clean_response_message(text) + if text is not None: + yield text.decode("utf-8"), {} diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 5c03db06819..cb47ee6d78b 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -112,6 +112,23 @@ def to_litellm_kwargs(self): "presence_penalty": self.presence_penalty, } + def to_shortfin_kwargs(self): + kwargs = { + "return_logprob": self.return_logprob, + "logprob_start_len": self.logprob_start_len, + "top_logprobs_num": self.top_logprobs_num, + } + kwargs["return_text_in_logprobs"] = ( + self.return_text_in_logprobs + if self.return_text_in_logprobs is not None + else False + ) + kwargs["sampling_params"] = { + "max_tokens": self.max_new_tokens, + "temperature": self.temperature, + } + return kwargs + def to_srt_kwargs(self): return { "max_new_tokens": self.max_new_tokens, diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 621efb5373c..4d466a7cb0c 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -75,7 +75,7 @@ def status_code(self): return self.resp.status -def http_request(url, json=None, stream=False, api_key=None, verify=None): +def http_request(url, json=None, stream=False, api_key=None, verify=None, timeout=None): """A faster version of requests.post with low-level urllib API.""" headers = {"Content-Type": 
"application/json; charset=utf-8"} @@ -84,7 +84,9 @@ def http_request(url, json=None, stream=False, api_key=None, verify=None): headers["Authorization"] = f"Bearer {api_key}" if stream: - return requests.post(url, json=json, stream=True, headers=headers) + return requests.post( + url, json=json, stream=True, headers=headers, timeout=timeout + ) else: req = urllib.request.Request(url, headers=headers) if json is None: @@ -93,7 +95,9 @@ def http_request(url, json=None, stream=False, api_key=None, verify=None): data = bytes(dumps(json), encoding="utf-8") try: - resp = urllib.request.urlopen(req, data=data, cafile=verify) + resp = urllib.request.urlopen( + req, data=data, cafile=verify, timeout=timeout + ) return HttpResponse(resp) except urllib.error.HTTPError as e: return HttpResponse(e) From 6b4aae19efe270da49f735c8e1d7b876440d7cb2 Mon Sep 17 00:00:00 2001 From: Stephen Baione Date: Thu, 24 Oct 2024 09:56:14 -0500 Subject: [PATCH 2/5] Add test for shortfin backend --- test/lang/test_shortfin_backend.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 test/lang/test_shortfin_backend.py diff --git a/test/lang/test_shortfin_backend.py b/test/lang/test_shortfin_backend.py new file mode 100644 index 00000000000..806a358a387 --- /dev/null +++ b/test/lang/test_shortfin_backend.py @@ -0,0 +1,25 @@ +import os +import unittest + +from sglang import Shortfin, set_default_backend +from sglang.test.test_programs import test_mt_bench, test_stream + + +class TestShortfinBackend(unittest.TestCase): + chat_backend = None + + @classmethod + def setUpClass(cls): + base_url = os.environ["SHORTFIN_BASE_URL"] + cls.chat_backend = Shortfin(base_url=base_url) + set_default_backend(cls.chat_backend) + + def test_mt_bench(self): + test_mt_bench() + + def test_stream(self): + test_stream() + + +if __name__ == "__main__": + unittest.main() From 59f1a53e6f3515d0bb128d8c56d626de3aa5dee3 Mon Sep 17 00:00:00 2001 From: Stephen Baione Date: Fri, 8 Nov 2024 09:30:02 +0000 Subject: [PATCH 3/5] Enable bench_serving script for Shortfin, Add examples for fork and batch, Update max_tokens to max_completion_tokens --- .../quick_start/shortfin_example_chat.py | 48 ++++++++++- python/sglang/bench_serving.py | 86 ++++++++++++++++++- python/sglang/lang/ir.py | 2 +- 3 files changed, 131 insertions(+), 5 deletions(-) diff --git a/examples/frontend_language/quick_start/shortfin_example_chat.py b/examples/frontend_language/quick_start/shortfin_example_chat.py index 3e943477dc0..c6d0f6a5acc 100644 --- a/examples/frontend_language/quick_start/shortfin_example_chat.py +++ b/examples/frontend_language/quick_start/shortfin_example_chat.py @@ -23,6 +23,21 @@ def multi_turn_question(s, question_1, question_2): s += sgl.user(question_2) s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) +@sgl.function +def tip_suggestion(s): + s += ( + "Here are two tips for staying healthy: " + "1. Balanced Diet. 2. 
Regular Exercise.\n\n" + ) + + forks = s.fork(2) + for i, f in enumerate(forks): + f += f"Now, expand tip {i+1} into a paragraph:\n" + f += sgl.gen(f"detailed_tip", max_tokens=256, stop="\n\n") + + s += "Tip 1:" + forks[0]["detailed_tip"] + "\n" + s += "Tip 2:" + forks[1]["detailed_tip"] + "\n" + s += "In summary" + sgl.gen("summary") def single(): state = multi_turn_question.run( @@ -35,7 +50,6 @@ def single(): print("\n-- answer_1 --\n", state["answer_1"]) - def stream(): state = multi_turn_question.run( question_1="What is the capital of the United States?", @@ -47,6 +61,30 @@ def stream(): print(out, end="", flush=True) print() +def fork(): + state = tip_suggestion.run() + print(state.text()) + +def batch(): + states = multi_turn_question.run_batch( + [ + { + "question_1": "What is the capital of the United States?", + "question_2": "List two local attractions.", + }, + { + "question_1": "What is the capital of France?", + "question_2": "What is the population of this city?", + }, + ] + ) + + for s in states: + for m in s.messages(): + print(m["role"], m["content"]) + + print() + print() if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -64,3 +102,11 @@ def stream(): # Stream output print("\n========== stream ==========\n") stream() + + # Run a single prompt in parallel + print("\n========== fork ==========\n") + fork() + + # Run a batch of prompts + print("\n========== batch ==========\n") + batch() diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 8bb452cd065..dc3465e0c8b 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -380,6 +380,84 @@ async def async_request_sglang_generate( return output +async def async_request_shortfin_generate( + request_func_input: RequestFuncInput, + pbar: Optional[tqdm] = None, +) -> RequestFuncOutput: + print("SNB: Using Shortfin Generate") + api_url = request_func_input.api_url + prompt = request_func_input.prompt + + async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: + payload = { + "text": prompt, + "sampling_params": { + "temperature": 0.0, + "max_completion_tokens": request_func_input.output_len, + "ignore_eos": not args.disable_ignore_eos, + }, + "stream": not args.disable_stream, + **request_func_input.extra_request_body, + } + headers = {} + + output = RequestFuncOutput() + output.prompt_len = request_func_input.prompt_len + + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + try: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + latency = time.perf_counter() - st + if chunk == "[DONE]": + pass + else: + data = chunk + + # NOTE: Some completion API might have a last + # usage summary response without a token so we + # want to check a token was generated + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text = data + + output.generated_text = generated_text + output.success = True + output.latency = latency + output.output_len = request_func_input.output_len + else: + output.error = response.reason or "" + output.success = False + except Exception: + 
output.success = False + exc_info = sys.exc_info() + output.error = "".join(traceback.format_exception(*exc_info)) + + if pbar: + pbar.update(1) + return output + + async def async_request_gserver( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, @@ -425,6 +503,7 @@ def get_tokenizer( "sglang": async_request_sglang_generate, "sglang-native": async_request_sglang_generate, "sglang-oai": async_request_openai_completions, + "shortfin": async_request_shortfin_generate, "vllm": async_request_openai_completions, "lmdeploy": async_request_openai_completions, "trt": async_request_trt_llm, @@ -954,6 +1033,7 @@ def run_benchmark(args_: argparse.Namespace): "trt": 8000, "gserver": 9988, "truss": 8080, + "shortfin": 8000, }.get(args.backend, 30000) model_url = ( @@ -962,7 +1042,7 @@ def run_benchmark(args_: argparse.Namespace): else f"http://{args.host}:{args.port}/v1/models" ) - if args.backend in ["sglang", "sglang-native"]: + if args.backend in ["sglang", "sglang-native", "shortfin"]: api_url = ( f"{args.base_url}/generate" if args.base_url @@ -994,7 +1074,7 @@ def run_benchmark(args_: argparse.Namespace): ) # Get model name - if args.model is None: + if args.model is None and args.backend != "shortfin": if args.backend == "truss": print( "Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct" @@ -1011,7 +1091,7 @@ def run_benchmark(args_: argparse.Namespace): ) sys.exit(1) - if args.model is None: + if args.model is None and args.backend != "shortfin": print("No model specified or found. Please provide a model using `--model`.") sys.exit(1) diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 86b5a6e7ae7..95d7c29751d 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -124,7 +124,7 @@ def to_shortfin_kwargs(self): else False ) kwargs["sampling_params"] = { - "max_tokens": self.max_new_tokens, + "max_completion_tokens": self.max_new_tokens, "temperature": self.temperature, } return kwargs From ded7e00b1fe3892c5933b45b3a8027c2b03382f3 Mon Sep 17 00:00:00 2001 From: Stephen Baione Date: Fri, 8 Nov 2024 09:37:09 +0000 Subject: [PATCH 4/5] Remove debug print statement --- python/sglang/bench_serving.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index dc3465e0c8b..553d135090a 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -384,7 +384,6 @@ async def async_request_shortfin_generate( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: - print("SNB: Using Shortfin Generate") api_url = request_func_input.api_url prompt = request_func_input.prompt From 4fd490c1e43a607265f29626cba9342ca05e1c4e Mon Sep 17 00:00:00 2001 From: Stephen Baione Date: Fri, 31 Jan 2025 18:44:20 +0000 Subject: [PATCH 5/5] Merge branch 'main' of https://github.com/sgl-project/sglang, fix shortfin backend streaming --- .clang-format-ignore | 1 + .devcontainer/Dockerfile | 35 + .devcontainer/devcontainer.json | 24 + .github/CODEOWNERS | 4 +- .github/pull_request_template.md | 7 +- .github/workflows/execute-notebook.yml | 2 +- .github/workflows/experiment-runner.yml | 30 + .../{nightly-eval.yml => nightly-test.yml} | 16 +- .github/workflows/pr-test-rust.yml | 19 +- .github/workflows/pr-test-sgl-kernel.yml | 103 ++ .github/workflows/pr-test.yml | 69 +- .github/workflows/release-docker-amd.yml | 16 +- .github/workflows/release-docker-dev.yml | 35 + .github/workflows/release-docker.yml | 6 
+- .github/workflows/release-docs.yml | 4 +- .github/workflows/release-pypi-kernel.yml | 44 + .github/workflows/release-pypi-router.yml | 11 +- .github/workflows/release-whl-kernel.yml | 92 ++ .gitignore | 7 + .gitmodules | 12 + 3rdparty/amd/profiling/PROFILING.md | 2 +- 3rdparty/amd/profiling/server.sh | 2 +- 3rdparty/amd/tuning/TUNING.md | 2 +- 3rdparty/amd/tuning/benchmark_moe_rocm.py | 5 +- Makefile | 15 +- README.md | 35 +- .../bench_in_batch_prefix.py | 130 ++ benchmark/blog_v0_2/405b_sglang.sh | 2 +- benchmark/deepseek_v3/README.md | 128 ++ benchmark/gsm8k/bench_sglang.py | 7 +- benchmark/hellaswag/bench_sglang.py | 7 +- benchmark/hicache/bench_multiturn.py | 334 +++++ .../triton_flashinfer_cudnn.py | 405 ++++++ benchmark/kernels/fused_moe_triton/README.md | 8 +- .../benchmark_deepseekv3_moe_align_blocks.py | 313 ++++ .../benchmark_torch_compile_fused_moe.py | 11 +- ...nchmark_vllm_vs_sglang_fused_moe_triton.py | 11 +- .../tuning_fused_moe_triton.py | 120 +- .../benchmark_lightning_attention_decode.py | 577 ++++++++ .../benchmark_lightning_attention_prefill.py | 603 ++++++++ .../kernels/quantization/bench_int8_quant.py | 94 ++ .../kernels/rmsnorm/benchmark_rmsnorm.py | 231 +++ ...enchmark_write_req_to_token_pool_triton.py | 345 +++++ .../tree_of_thought_deep/bench_sglang.py | 1 + docker/Dockerfile | 18 + docker/Dockerfile.dev | 13 +- docker/Dockerfile.rocm | 29 +- docs/README.md | 82 +- docs/backend/backend.md | 168 --- docs/backend/function_calling.ipynb | 523 +++++++ docs/backend/native_api.ipynb | 70 + docs/backend/offline_engine_api.ipynb | 48 +- docs/backend/openai_api_completions.ipynb | 90 +- docs/backend/openai_api_embeddings.ipynb | 9 +- docs/backend/openai_api_vision.ipynb | 8 +- docs/backend/server_arguments.md | 184 +++ docs/backend/speculative_decoding.ipynb | 182 +++ docs/backend/structured_outputs.ipynb | 598 ++++++++ .../development_guide_using_docker.md | 47 + docs/developer/setup_github_runner.md | 4 +- docs/index.rst | 10 +- docs/references/benchmark_and_profiling.md | 23 +- docs/references/contribution_guide.md | 45 + docs/references/contributor_guide.md | 14 - docs/references/deepseek.md | 56 + docs/references/llama_405B.md | 19 + docs/references/modelscope.md | 28 + docs/references/production_metrics.md | 280 ++-- docs/references/sampling_params.md | 149 +- docs/references/supported_models.md | 12 +- docs/references/torch_compile_cache.md | 13 + docs/router/router.md | 66 +- docs/start/install.md | 19 +- .../quick_start/shortfin_example_chat.py | 7 + .../frontend_language/usage/json_decode.py | 2 +- .../models/character_generation/1/model.py | 4 +- examples/runtime/async_io_api.py | 46 - .../engine/EAGLE_offline_batch_inference.py | 37 + .../runtime/engine/offline_batch_inference.py | 5 + python/pyproject.toml | 33 +- python/sglang/README.md | 1 + python/sglang/__init__.py | 47 +- python/sglang/api.py | 9 +- python/sglang/bench_offline_throughput.py | 81 +- python/sglang/bench_one_batch.py | 67 +- python/sglang/bench_one_batch_server.py | 2 +- python/sglang/bench_serving.py | 185 ++- python/sglang/check_env.py | 188 ++- python/sglang/lang/backend/openai.py | 10 + .../sglang/lang/backend/runtime_endpoint.py | 188 ++- python/sglang/lang/backend/shortfin.py | 18 +- python/sglang/lang/chat_template.py | 82 +- python/sglang/lang/interpreter.py | 72 +- python/sglang/lang/ir.py | 4 +- python/sglang/launch_server.py | 2 +- python/sglang/launch_server_llavavid.py | 25 - python/sglang/llama3_eval.py | 316 ++++ python/sglang/srt/_custom_ops.py | 122 +- 
python/sglang/srt/aio_rwlock.py | 100 ++ python/sglang/srt/configs/__init__.py | 4 + python/sglang/srt/configs/chatglm.py | 78 + python/sglang/srt/configs/dbrx.py | 279 ++++ python/sglang/srt/configs/device_config.py | 2 +- python/sglang/srt/configs/load_config.py | 1 + python/sglang/srt/configs/model_config.py | 42 +- python/sglang/srt/constrained/__init__.py | 16 - .../srt/constrained/base_grammar_backend.py | 21 + .../srt/constrained/xgrammar_backend.py | 23 +- python/sglang/srt/conversation.py | 15 +- python/sglang/srt/distributed/__init__.py | 6 +- .../srt/distributed/communication_op.py | 3 +- .../device_communicators/cuda_wrapper.py | 3 +- .../device_communicators/custom_all_reduce.py | 154 +- .../custom_all_reduce_utils.py | 4 +- .../device_communicators/hpu_communicator.py | 3 +- .../device_communicators/pynccl.py | 81 +- .../device_communicators/pynccl_wrapper.py | 114 +- .../device_communicators/shm_broadcast.py | 77 +- .../device_communicators/xpu_communicator.py | 3 +- .../sglang/srt/distributed/parallel_state.py | 2 +- python/sglang/srt/distributed/utils.py | 3 +- python/sglang/srt/entrypoints/engine.py | 452 ++++++ python/sglang/srt/entrypoints/http_server.py | 603 ++++++++ python/sglang/srt/function_call_parser.py | 494 +++++++ python/sglang/srt/hf_transformers_utils.py | 23 +- python/sglang/srt/layers/activation.py | 16 +- .../sglang/srt/layers/attention/__init__.py | 29 +- .../attention/double_sparsity_backend.py | 52 - .../layers/attention/flashinfer_backend.py | 478 ++++-- .../layers/attention/torch_native_backend.py | 39 +- .../srt/layers/attention/triton_backend.py | 82 +- .../attention/triton_ops/decode_attention.py | 659 ++++----- .../attention/triton_ops/extend_attention.py | 34 +- .../attention/triton_ops/prefill_attention.py | 6 + python/sglang/srt/layers/attention/vision.py | 407 ++++++ python/sglang/srt/layers/dp_attention.py | 71 + python/sglang/srt/layers/ep_moe/__init__.py | 0 python/sglang/srt/layers/fused_moe_patch.py | 133 -- python/sglang/srt/layers/layernorm.py | 10 +- python/sglang/srt/layers/linear.py | 221 ++- python/sglang/srt/layers/logits_processor.py | 464 +++--- .../moe/ep_moe}/__init__.py | 0 .../srt/layers/{ => moe}/ep_moe/kernels.py | 0 .../srt/layers/{ => moe}/ep_moe/layer.py | 97 +- .../sglang/srt/layers/moe/fused_moe_native.py | 129 ++ .../{ => moe}/fused_moe_triton/__init__.py | 10 +- ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 0 ...336,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 0 ...792,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 0 ...VIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json | 0 ...072,device_name=NVIDIA_H100_80GB_HBM3.json | 0 ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 0 ...584,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 0 ...168,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...344,device_name=NVIDIA_A100-SXM4-40GB.json | 0 ...344,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...344,device_name=NVIDIA_H100_80GB_HBM3.json | 0 ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 0 ...336,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 0 ...792,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...688,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...688,device_name=NVIDIA_H100_80GB_HBM3.json | 0 ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 0 ...VIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json | 0 ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 0 
...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 0 ...584,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 0 ...VIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json | 0 ...168,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...VIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json | 0 ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 0 ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...280,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...280,device_name=NVIDIA_A800-SXM4-80GB.json | 146 ++ ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 146 ++ ...280,device_name=NVIDIA_H100_80GB_HBM3.json | 62 +- ...evice_name=NVIDIA_H200,dtype=fp8_w8a8.json | 146 ++ .../E=64,N=1280,device_name=NVIDIA_H200.json | 146 ++ ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 146 ++ ...evice_name=NVIDIA_H200,dtype=fp8_w8a8.json | 146 ++ .../E=64,N=2560,device_name=NVIDIA_H200.json | 146 ++ ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 146 ++ ...320,device_name=NVIDIA_H100_80GB_HBM3.json | 146 ++ ...evice_name=NVIDIA_H200,dtype=fp8_w8a8.json | 146 ++ .../E=64,N=320,device_name=NVIDIA_H200.json | 146 ++ ...640,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...640,device_name=NVIDIA_A800-SXM4-80GB.json | 146 ++ ...VIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json | 0 ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 146 ++ ...640,device_name=NVIDIA_H100_80GB_HBM3.json | 42 +- ...evice_name=NVIDIA_H200,dtype=fp8_w8a8.json | 146 ++ .../E=64,N=640,device_name=NVIDIA_H200.json | 146 ++ ...14336,device_name=AMD_Instinct_MI300X.json | 0 ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 0 ...evice_name=NVIDIA_H200,dtype=fp8_w8a8.json | 146 ++ .../E=8,N=14336,device_name=NVIDIA_H200.json | 146 ++ ...=1792,device_name=AMD_Instinct_MI300X.json | 0 ...792,device_name=NVIDIA_A100-SXM4-40GB.json | 0 ...792,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...792,device_name=NVIDIA_H100_80GB_HBM3.json | 0 ...evice_name=NVIDIA_H200,dtype=fp8_w8a8.json | 146 ++ .../E=8,N=1792,device_name=NVIDIA_H200.json | 146 ++ ...048,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 0 ...048,device_name=NVIDIA_H100_80GB_HBM3.json | 0 ...evice_name=NVIDIA_H200,dtype=fp8_w8a8.json | 146 ++ .../E=8,N=2048,device_name=NVIDIA_H200.json | 146 ++ ...=3584,device_name=AMD_Instinct_MI300X.json | 0 ...584,device_name=NVIDIA_A100-SXM4-40GB.json | 0 ...584,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...VIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json | 0 ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 0 ...584,device_name=NVIDIA_H100_80GB_HBM3.json | 0 ...evice_name=NVIDIA_H200,dtype=fp8_w8a8.json | 146 ++ .../E=8,N=3584,device_name=NVIDIA_H200.json | 146 ++ .../E=8,N=3584,device_name=NVIDIA_L40S.json | 0 ...me=AMD_Instinct_MI300X,dtype=fp8_w8a8.json | 28 +- ...096,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 0 ...096,device_name=NVIDIA_H100_80GB_HBM3.json | 0 ...evice_name=NVIDIA_H200,dtype=fp8_w8a8.json | 146 ++ .../E=8,N=4096,device_name=NVIDIA_H200.json | 146 ++ ...=7168,device_name=AMD_Instinct_MI300X.json | 0 ...168,device_name=NVIDIA_A100-SXM4-80GB.json | 0 ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 0 ...168,device_name=NVIDIA_H100_80GB_HBM3.json | 0 ...evice_name=NVIDIA_H200,dtype=fp8_w8a8.json | 146 ++ .../E=8,N=7168,device_name=NVIDIA_H200.json | 146 ++ ...me=AMD_Instinct_MI300X,dtype=fp8_w8a8.json | 0 ...=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json | 0 
...evice_name=NVIDIA_H200,dtype=fp8_w8a8.json | 146 ++ .../{ => moe}/fused_moe_triton/configs/README | 2 + .../{ => moe}/fused_moe_triton/fused_moe.py | 541 +++++-- .../{ => moe}/fused_moe_triton/layer.py | 176 ++- python/sglang/srt/layers/moe/topk.py | 211 +++ python/sglang/srt/layers/parameter.py | 449 ++++++ .../srt/layers/quantization/__init__.py | 101 +- ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 164 +++ ...dtype=fp8_w8a8,block_shape=[128, 128].json | 146 ++ python/sglang/srt/layers/quantization/fp8.py | 396 ++++- .../srt/layers/quantization/fp8_kernel.py | 351 +++++ .../srt/layers/quantization/fp8_utils.py | 97 +- .../srt/layers/quantization/int8_kernel.py | 54 + .../srt/layers/quantization/modelopt_quant.py | 173 +++ .../srt/layers/quantization/w8a8_int8.py | 117 ++ python/sglang/srt/layers/radix_attention.py | 11 +- python/sglang/srt/layers/rotary_embedding.py | 1231 +++++++++++++++- python/sglang/srt/layers/sampler.py | 145 +- python/sglang/srt/layers/torchao_utils.py | 63 +- .../srt/layers/vocab_parallel_embedding.py | 23 +- 
python/sglang/srt/lora/lora.py | 10 +- .../sglang/srt/managers/cache_controller.py | 307 ++++ .../sglang/srt/managers/configure_logging.py | 46 + .../srt/managers/data_parallel_controller.py | 153 +- .../srt/managers/detokenizer_manager.py | 82 +- python/sglang/srt/managers/image_processor.py | 199 ++- python/sglang/srt/managers/io_struct.py | 171 ++- python/sglang/srt/managers/schedule_batch.py | 274 ++-- python/sglang/srt/managers/schedule_policy.py | 284 +++- python/sglang/srt/managers/scheduler.py | 990 ++++++++----- .../sglang/srt/managers/session_controller.py | 130 +- .../sglang/srt/managers/tokenizer_manager.py | 835 +++++++---- python/sglang/srt/managers/tp_worker.py | 35 +- .../srt/managers/tp_worker_overlap_thread.py | 25 +- python/sglang/srt/managers/utils.py | 44 + .../sglang/srt/mem_cache/base_prefix_cache.py | 8 +- python/sglang/srt/mem_cache/chunk_cache.py | 7 +- python/sglang/srt/mem_cache/memory_pool.py | 405 +++++- python/sglang/srt/mem_cache/radix_cache.py | 47 +- python/sglang/srt/metrics/collector.py | 88 +- .../srt/model_executor/cuda_graph_runner.py | 230 +-- .../srt/model_executor/forward_batch_info.py | 91 +- .../sglang/srt/model_executor/model_runner.py | 268 ++-- python/sglang/srt/model_loader/loader.py | 122 +- .../sglang/srt/model_loader/weight_utils.py | 151 +- python/sglang/srt/model_parallel.py | 71 +- python/sglang/srt/models/baichuan.py | 12 +- python/sglang/srt/models/chatglm.py | 6 +- python/sglang/srt/models/commandr.py | 22 +- python/sglang/srt/models/dbrx.py | 20 +- python/sglang/srt/models/deepseek.py | 8 +- python/sglang/srt/models/deepseek_v2.py | 138 +- python/sglang/srt/models/exaone.py | 4 +- python/sglang/srt/models/gemma.py | 4 +- python/sglang/srt/models/gemma2.py | 93 +- python/sglang/srt/models/gemma2_reward.py | 1 - python/sglang/srt/models/gpt2.py | 8 +- python/sglang/srt/models/gpt_bigcode.py | 2 +- python/sglang/srt/models/granite.py | 517 +++++++ python/sglang/srt/models/grok.py | 203 ++- python/sglang/srt/models/internlm2.py | 4 +- python/sglang/srt/models/llama.py | 93 +- .../sglang/srt/models/llama_classification.py | 34 +- python/sglang/srt/models/llama_eagle.py | 132 ++ python/sglang/srt/models/llama_reward.py | 2 - python/sglang/srt/models/llava.py | 51 +- python/sglang/srt/models/minicpm.py | 4 +- python/sglang/srt/models/minicpm3.py | 18 +- python/sglang/srt/models/minicpmv.py | 1291 +++++++++++++++++ python/sglang/srt/models/mixtral.py | 10 +- python/sglang/srt/models/mixtral_quant.py | 6 +- python/sglang/srt/models/mllama.py | 76 +- python/sglang/srt/models/olmo.py | 6 +- python/sglang/srt/models/olmo2.py | 8 +- python/sglang/srt/models/olmoe.py | 22 +- python/sglang/srt/models/phi3_small.py | 4 +- python/sglang/srt/models/qwen.py | 4 +- python/sglang/srt/models/qwen2.py | 77 +- python/sglang/srt/models/qwen2_eagle.py | 131 ++ python/sglang/srt/models/qwen2_moe.py | 8 +- python/sglang/srt/models/qwen2_vl.py | 168 +-- python/sglang/srt/models/stablelm.py | 4 +- .../sglang/srt/models/torch_native_llama.py | 27 +- python/sglang/srt/models/xverse.py | 12 +- python/sglang/srt/models/xverse_moe.py | 14 +- python/sglang/srt/openai_api/adapter.py | 232 ++- python/sglang/srt/openai_api/protocol.py | 64 +- .../srt/sampling/custom_logit_processor.py | 38 + .../penalizers/repetition_penalty.py | 17 +- .../srt/sampling/sampling_batch_info.py | 164 ++- python/sglang/srt/sampling/sampling_params.py | 25 +- python/sglang/srt/server.py | 1027 +------------ python/sglang/srt/server_args.py | 333 +++-- .../srt/speculative/build_eagle_tree.py | 
347 +++++ python/sglang/srt/speculative/eagle_utils.py | 648 +++++++++ python/sglang/srt/speculative/eagle_worker.py | 183 +++ python/sglang/srt/speculative/spec_info.py | 24 + .../sglang/srt/torch_memory_saver_adapter.py | 59 + python/sglang/srt/utils.py | 380 +++-- python/sglang/test/runners.py | 21 +- python/sglang/test/test_block_fp8.py | 341 +++++ python/sglang/test/test_programs.py | 25 +- python/sglang/test/test_utils.py | 151 +- python/sglang/utils.py | 70 +- python/sglang/version.py | 2 +- rust/README.md | 183 --- rust/py_test/test_launch_server.py | 184 --- rust/src/main.rs | 125 -- rust/src/router.rs | 399 ----- rust/src/server.rs | 208 --- scripts/ci_install_dependency.sh | 15 +- scripts/ci_install_rust.sh | 14 +- .../deprecated/test_httpserver_classify.py | 69 - .../test_httpserver_decode_stream.py | 1 - scripts/deprecated/test_jump_forward.py | 2 +- scripts/killall_sglang.sh | 32 +- scripts/playground/reference_hf.py | 86 +- scripts/update_kernel_whl_index.py | 16 + scripts/version_branch_to_tag.sh | 1 + sgl-kernel/.clang-format | 8 + sgl-kernel/3rdparty/cccl | 1 + sgl-kernel/3rdparty/cutlass | 1 + sgl-kernel/3rdparty/flashinfer | 1 + sgl-kernel/3rdparty/turbomind | 1 + sgl-kernel/CMakeLists.txt | 47 - sgl-kernel/Makefile | 25 +- sgl-kernel/README.md | 16 +- sgl-kernel/THIRDPARTYNOTICES.txt | 430 ++++++ sgl-kernel/benchmark/bench_fp8_gemm.py | 164 +++ sgl-kernel/benchmark/bench_int8_gemm.py | 146 ++ .../bench_lightning_attention_decode.py | 299 ++++ sgl-kernel/build.sh | 25 +- sgl-kernel/developer_guide.md | 55 + sgl-kernel/pyproject.toml | 13 +- sgl-kernel/setup.py | 175 ++- sgl-kernel/src/sgl-kernel/__init__.py | 52 +- .../epilogue/epilogue_per_row_per_col_scale.h | 275 ++++ .../gemm/gemm_universal_base_compat.h | 339 +++++ .../gemm/gemm_with_epilogue_visitor.h | 453 ++++++ .../src/sgl-kernel/csrc/fp8_gemm_kernel.cu | 624 ++++++++ .../csrc/fused_add_rms_norm_kernel.cu | 35 + .../src/sgl-kernel/csrc/int8_gemm_kernel.cu | 428 ++++++ .../csrc/lightning_attention_decode_kernel.cu | 118 ++ .../src/sgl-kernel/csrc/moe_align_kernel.cu | 101 ++ .../sgl-kernel/csrc/trt_reduce_internal.cu | 515 +++++++ .../src/sgl-kernel/csrc/trt_reduce_kernel.cu | 201 +++ sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc | 20 - .../src/sgl-kernel/csrc/warp_reduce_kernel.cu | 97 -- .../src/sgl-kernel/include/sgl_kernels_ops.h | 114 ++ .../include/trt_reduce_internal.cuh | 94 ++ sgl-kernel/src/sgl-kernel/include/utils.h | 66 + sgl-kernel/src/sgl-kernel/ops/__init__.py | 498 ++++++- sgl-kernel/src/sgl-kernel/ops/utils.py | 26 + sgl-kernel/src/sgl-kernel/torch_extension.cc | 120 ++ sgl-kernel/tests/.gitkeep | 0 sgl-kernel/tests/test_activation.py | 39 + sgl-kernel/tests/test_bmm_fp8.py | 43 + sgl-kernel/tests/test_fp8_gemm.py | 67 + sgl-kernel/tests/test_int8_gemm.py | 56 + .../tests/test_lightning_attention_decode.py | 88 ++ sgl-kernel/tests/test_moe_align.py | 67 + sgl-kernel/tests/test_norm.py | 133 ++ sgl-kernel/tests/test_rotary_embedding.py | 202 +++ sgl-kernel/tests/test_sampling.py | 141 ++ sgl-kernel/tests/test_trt_reduce.py | 246 ++++ sgl-kernel/version.py | 1 + {rust => sgl-router}/Cargo.lock | 9 +- {rust => sgl-router}/Cargo.toml | 7 +- {rust => sgl-router}/MANIFEST.in | 0 sgl-router/README.md | 97 ++ .../py_src/sglang_router/__init__.py | 6 +- .../py_src/sglang_router/launch_router.py | 48 +- .../py_src/sglang_router/launch_server.py | 158 +- .../py_src/sglang_router/router.py | 9 + sgl-router/py_src/sglang_router/version.py | 1 + {rust => sgl-router}/py_test/run_suite.py | 0 
.../py_test/test_launch_router.py | 29 +- sgl-router/py_test/test_launch_server.py | 394 +++++ {rust => sgl-router}/pyproject.toml | 6 +- {rust => sgl-router}/src/lib.rs | 32 +- sgl-router/src/router.rs | 809 +++++++++++ sgl-router/src/server.rs | 199 +++ {rust => sgl-router}/src/tree.rs | 0 sgl-router/v0.1.0.md | 63 + test/README.md | 21 +- test/lang/run_suite.py | 6 +- test/lang/test_srt_backend.py | 3 +- test/srt/configs/random_config.yaml | 25 + .../random_flashinfer_vs_triton_config.yaml | 25 + test/srt/configs/sharegpt_config.yaml | 7 + test/srt/experiment_runner.py | 359 +++++ test/srt/kv_cache_scales_llama3_1_8b.json | 42 + test/srt/kv_cache_scales_llama3_8b.json | 42 + test/srt/kv_cache_scales_qwen2_1_5b.json | 38 + test/srt/models/test_generation_models.py | 1 + test/srt/models/test_qwen_models.py | 76 + test/srt/models/test_reward_models.py | 4 +- test/srt/run_suite.py | 25 +- .../test_srt_endpoint_with_penalizers.py | 9 +- test/srt/test_bench_one_batch.py | 26 +- test/srt/test_bench_serving.py | 44 +- test/srt/test_custom_allreduce.py | 164 +++ test/srt/test_eagle_infer.py | 180 +++ test/srt/test_ebnf_constrained.py | 240 +++ test/srt/test_fp8_kernel.py | 127 ++ test/srt/test_fp8_kvcache.py | 113 ++ test/srt/test_function_calling.py | 249 ++++ test/srt/test_fused_moe.py | 126 ++ test/srt/test_json_constrained.py | 9 - test/srt/test_metrics.py | 4 +- test/srt/test_mla.py | 35 +- test/srt/test_mla_fp8.py | 2 - test/srt/test_moe_ep.py | 4 +- test/srt/test_moe_eval_accuracy_large.py | 2 +- test/srt/test_nightly_gsm8k_eval.py | 55 +- test/srt/test_nightly_human_eval.py | 27 +- test/srt/test_nightly_math_eval.py | 46 + test/srt/test_openai_server.py | 129 ++ test/srt/test_regex_constrained.py | 186 +++ test/srt/test_release_memory_occupation.py | 98 ++ test/srt/test_request_length_validation.py | 71 + test/srt/test_schedule_policy.py | 52 + test/srt/test_session_control.py | 423 +++++- test/srt/test_skip_tokenizer_init.py | 117 +- test/srt/test_srt_endpoint.py | 254 +++- test/srt/test_srt_engine.py | 148 +- test/srt/test_srt_engine_with_quant_args.py | 60 + test/srt/test_torch_compile.py | 2 +- test/srt/test_triton_attention_backend.py | 2 +- test/srt/test_triton_attention_kernels.py | 45 +- test/srt/test_update_weights_from_tensor.py | 38 + test/srt/test_vision_chunked_prefill.py | 173 +++ test/srt/test_vision_llm.py | 210 +++ test/srt/test_vision_openai_server.py | 76 +- test/srt/test_w8a8_quantization.py | 74 + 517 files changed, 48682 insertions(+), 7617 deletions(-) create mode 100644 .clang-format-ignore create mode 100644 .devcontainer/Dockerfile create mode 100644 .devcontainer/devcontainer.json create mode 100644 .github/workflows/experiment-runner.yml rename .github/workflows/{nightly-eval.yml => nightly-test.yml} (68%) create mode 100644 .github/workflows/pr-test-sgl-kernel.yml create mode 100644 .github/workflows/release-docker-dev.yml create mode 100644 .github/workflows/release-pypi-kernel.yml create mode 100644 .github/workflows/release-whl-kernel.yml create mode 100644 benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py create mode 100644 benchmark/deepseek_v3/README.md create mode 100644 benchmark/hicache/bench_multiturn.py create mode 100644 benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py create mode 100644 benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py create mode 100644 benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py create mode 100644 
benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py create mode 100644 benchmark/kernels/quantization/bench_int8_quant.py create mode 100644 benchmark/kernels/rmsnorm/benchmark_rmsnorm.py create mode 100644 benchmark/kernels/scheduler_batch/benchmark_write_req_to_token_pool_triton.py delete mode 100644 docs/backend/backend.md create mode 100644 docs/backend/function_calling.ipynb create mode 100644 docs/backend/server_arguments.md create mode 100644 docs/backend/speculative_decoding.ipynb create mode 100644 docs/backend/structured_outputs.ipynb create mode 100644 docs/developer/development_guide_using_docker.md create mode 100644 docs/references/contribution_guide.md delete mode 100644 docs/references/contributor_guide.md create mode 100644 docs/references/deepseek.md create mode 100644 docs/references/llama_405B.md create mode 100644 docs/references/modelscope.md create mode 100644 docs/references/torch_compile_cache.md delete mode 100644 examples/runtime/async_io_api.py create mode 100644 examples/runtime/engine/EAGLE_offline_batch_inference.py delete mode 100644 python/sglang/launch_server_llavavid.py create mode 100644 python/sglang/llama3_eval.py create mode 100644 python/sglang/srt/aio_rwlock.py create mode 100644 python/sglang/srt/configs/chatglm.py create mode 100644 python/sglang/srt/configs/dbrx.py delete mode 100644 python/sglang/srt/constrained/__init__.py create mode 100644 python/sglang/srt/entrypoints/engine.py create mode 100644 python/sglang/srt/entrypoints/http_server.py create mode 100644 python/sglang/srt/function_call_parser.py create mode 100644 python/sglang/srt/layers/attention/vision.py create mode 100644 python/sglang/srt/layers/dp_attention.py delete mode 100644 python/sglang/srt/layers/ep_moe/__init__.py delete mode 100644 python/sglang/srt/layers/fused_moe_patch.py rename python/sglang/srt/{distributed/device_communicators => layers/moe/ep_moe}/__init__.py (100%) rename python/sglang/srt/layers/{ => moe}/ep_moe/kernels.py (100%) rename python/sglang/srt/layers/{ => moe}/ep_moe/layer.py (91%) create mode 100644 python/sglang/srt/layers/moe/fused_moe_native.py rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/__init__.py (72%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => 
moe}/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json (100%) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json (100%) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json create mode 100644 
python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json (87%) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json (100%) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json (100%) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json (90%) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json (100%) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json (100%) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json rename python/sglang/srt/layers/{ => 
moe}/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json (100%) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json (100%) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json (87%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json (100%) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json (100%) create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json (100%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json (100%) create mode 100644 
python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/configs/README (84%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/fused_moe.py (63%) rename python/sglang/srt/layers/{ => moe}/fused_moe_triton/layer.py (85%) create mode 100644 python/sglang/srt/layers/moe/topk.py create mode 100644 python/sglang/srt/layers/parameter.py create mode 100644 python/sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 
python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 
python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json create mode 100644 python/sglang/srt/layers/quantization/fp8_kernel.py create mode 100644 python/sglang/srt/layers/quantization/int8_kernel.py create mode 100644 python/sglang/srt/layers/quantization/modelopt_quant.py create mode 100644 python/sglang/srt/layers/quantization/w8a8_int8.py create mode 100644 python/sglang/srt/managers/cache_controller.py create mode 100644 python/sglang/srt/managers/configure_logging.py create mode 100644 python/sglang/srt/managers/utils.py create mode 100644 python/sglang/srt/models/granite.py create mode 100644 python/sglang/srt/models/llama_eagle.py create mode 100644 python/sglang/srt/models/minicpmv.py mode change 100755 => 100644 python/sglang/srt/models/olmo2.py create mode 100644 python/sglang/srt/models/qwen2_eagle.py create mode 100644 python/sglang/srt/sampling/custom_logit_processor.py create mode 100644 python/sglang/srt/speculative/build_eagle_tree.py create mode 100644 python/sglang/srt/speculative/eagle_utils.py create mode 100644 python/sglang/srt/speculative/eagle_worker.py create mode 100644 python/sglang/srt/speculative/spec_info.py create mode 100644 python/sglang/srt/torch_memory_saver_adapter.py create mode 100644 python/sglang/test/test_block_fp8.py delete mode 100644 rust/README.md delete mode 100644 rust/py_test/test_launch_server.py delete mode 100644 rust/src/main.rs delete mode 100644 rust/src/router.rs delete mode 100644 rust/src/server.rs delete mode 100644 scripts/deprecated/test_httpserver_classify.py create mode 100644 scripts/update_kernel_whl_index.py create mode 100644 sgl-kernel/.clang-format create mode 160000 sgl-kernel/3rdparty/cccl create mode 160000 sgl-kernel/3rdparty/cutlass create mode 160000 sgl-kernel/3rdparty/flashinfer create mode 160000 sgl-kernel/3rdparty/turbomind delete mode 100644 sgl-kernel/CMakeLists.txt create mode 100644 sgl-kernel/THIRDPARTYNOTICES.txt create mode 100644 sgl-kernel/benchmark/bench_fp8_gemm.py create mode 100644 sgl-kernel/benchmark/bench_int8_gemm.py create mode 100644 sgl-kernel/benchmark/bench_lightning_attention_decode.py create mode 100644 sgl-kernel/developer_guide.md create mode 100644 sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h create mode 100644 sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h create mode 100644 sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h create mode 100644 sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu create mode 100644 sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu create mode 100644 sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu create mode 100644 sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu create mode 100644 sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu create mode 100644 sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu create mode 100644 sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu delete mode 100644 
sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc delete mode 100644 sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu create mode 100644 sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h create mode 100644 sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh create mode 100644 sgl-kernel/src/sgl-kernel/include/utils.h create mode 100644 sgl-kernel/src/sgl-kernel/ops/utils.py create mode 100644 sgl-kernel/src/sgl-kernel/torch_extension.cc delete mode 100644 sgl-kernel/tests/.gitkeep create mode 100644 sgl-kernel/tests/test_activation.py create mode 100644 sgl-kernel/tests/test_bmm_fp8.py create mode 100644 sgl-kernel/tests/test_fp8_gemm.py create mode 100644 sgl-kernel/tests/test_int8_gemm.py create mode 100644 sgl-kernel/tests/test_lightning_attention_decode.py create mode 100644 sgl-kernel/tests/test_moe_align.py create mode 100644 sgl-kernel/tests/test_norm.py create mode 100644 sgl-kernel/tests/test_rotary_embedding.py create mode 100644 sgl-kernel/tests/test_sampling.py create mode 100644 sgl-kernel/tests/test_trt_reduce.py create mode 100644 sgl-kernel/version.py rename {rust => sgl-router}/Cargo.lock (99%) rename {rust => sgl-router}/Cargo.toml (86%) rename {rust => sgl-router}/MANIFEST.in (100%) create mode 100644 sgl-router/README.md rename {rust => sgl-router}/py_src/sglang_router/__init__.py (52%) rename {rust => sgl-router}/py_src/sglang_router/launch_router.py (82%) rename {rust => sgl-router}/py_src/sglang_router/launch_server.py (56%) rename {rust => sgl-router}/py_src/sglang_router/router.py (82%) create mode 100644 sgl-router/py_src/sglang_router/version.py rename {rust => sgl-router}/py_test/run_suite.py (100%) rename {rust => sgl-router}/py_test/test_launch_router.py (63%) create mode 100644 sgl-router/py_test/test_launch_server.py rename {rust => sgl-router}/pyproject.toml (87%) rename {rust => sgl-router}/src/lib.rs (70%) create mode 100644 sgl-router/src/router.rs create mode 100644 sgl-router/src/server.rs rename {rust => sgl-router}/src/tree.rs (100%) create mode 100644 sgl-router/v0.1.0.md create mode 100644 test/srt/configs/random_config.yaml create mode 100644 test/srt/configs/random_flashinfer_vs_triton_config.yaml create mode 100644 test/srt/configs/sharegpt_config.yaml create mode 100644 test/srt/experiment_runner.py create mode 100644 test/srt/kv_cache_scales_llama3_1_8b.json create mode 100644 test/srt/kv_cache_scales_llama3_8b.json create mode 100644 test/srt/kv_cache_scales_qwen2_1_5b.json create mode 100644 test/srt/models/test_qwen_models.py create mode 100644 test/srt/test_custom_allreduce.py create mode 100644 test/srt/test_eagle_infer.py create mode 100644 test/srt/test_ebnf_constrained.py create mode 100644 test/srt/test_fp8_kernel.py create mode 100644 test/srt/test_fp8_kvcache.py create mode 100644 test/srt/test_function_calling.py create mode 100644 test/srt/test_fused_moe.py create mode 100644 test/srt/test_nightly_math_eval.py create mode 100644 test/srt/test_regex_constrained.py create mode 100644 test/srt/test_release_memory_occupation.py create mode 100644 test/srt/test_request_length_validation.py create mode 100644 test/srt/test_schedule_policy.py create mode 100644 test/srt/test_srt_engine_with_quant_args.py create mode 100644 test/srt/test_update_weights_from_tensor.py create mode 100644 test/srt/test_vision_chunked_prefill.py create mode 100644 test/srt/test_vision_llm.py create mode 100644 test/srt/test_w8a8_quantization.py diff --git a/.clang-format-ignore b/.clang-format-ignore new file mode 100644 index 00000000000..15c76cc457f 
--- /dev/null +++ b/.clang-format-ignore @@ -0,0 +1 @@ +sgl-kernel/3rdparty/tensorrt_llm/* diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 00000000000..0c061cd1871 --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,35 @@ +From lmsysorg/sglang:dev + +# Create non-root user with specified UID and GID +# NOTE: Replace with your own UID and GID. This is a workaround from https://github.com/microsoft/vscode-remote-release/issues/49#issuecomment-489060908. +ARG HOST_UID=1003 +ARG HOST_GID=1003 +RUN groupadd -g $HOST_GID devuser && \ + useradd -m -u $HOST_UID -g $HOST_GID -s /bin/zsh devuser + +# Give devuser sudo access +RUN apt-get update && apt-get install -y sudo && \ + echo "devuser ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/devuser && \ + rm -rf /var/lib/apt/lists/* && \ + apt-get clean + +# Set up oh-my-zsh for devuser +RUN cp -r /root/.oh-my-zsh /home/devuser/.oh-my-zsh && \ + cp /root/.zshrc /home/devuser/.zshrc && \ + cp /root/.vimrc /home/devuser/.vimrc && \ + cp /root/.tmux.conf /home/devuser/.tmux.conf && \ + sed -i 's|/root/.oh-my-zsh|/home/devuser/.oh-my-zsh|g' /home/devuser/.zshrc && \ + chown -R devuser:devuser /home/devuser/ + +# Set workspace directory and ownership +WORKDIR /sgl-workspace/sglang +RUN chown -R devuser:devuser /sgl-workspace + +# Switch to devuser +USER devuser + +# Install uv +RUN curl -LsSf https://astral.sh/uv/install.sh | sh + +# Install rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 00000000000..5767aa2631a --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,24 @@ +{ + "name": "sglang", + "build": { + "dockerfile": "Dockerfile" + }, + "remoteUser": "devuser", + "customizations": { + "vscode": { + "extensions": [ + // Python development + "ms-python.python", + "charliermarsh.ruff", + // Rust development + "rust-lang.rust-analyzer", + "tamasfe.even-better-toml" + ] + } + }, + "forwardPorts": [], + "runArgs": [ + "--gpus", + "all" + ] +} diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 62612312f55..0f5d7459c39 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,6 +2,7 @@ /python/sglang/srt @merrymercy @Ying1123 @hnyls2002 @zhyncs @ispobock @ByronHsu /python/sglang/srt/constrained @hnyls2002 /python/sglang/srt/layers @merrymercy @Ying1123 @zhyncs @ispobock +/python/sglang/srt/layers/moe/fused_moe_triton @zhyncs @ispobock @HaiShaw /python/sglang/srt/lora @Ying1123 /python/sglang/srt/managers @merrymercy @Ying1123 @hnyls2002 /python/sglang/srt/mem_cache @merrymercy @Ying1123 @hnyls2002 @@ -11,4 +12,5 @@ /python/sglang/srt/sampling @merrymercy @hnyls2002 /test/lang @merrymercy @Ying1123 @ByronHsu /test/srt @merrymercy @Ying1123 @zhyncs -/rust @ByronHsu @Ying1123 +/sgl-router @ByronHsu @Ying1123 +/sgl-kernel @zhyncs @ispobock @HandH1998 @BBuf @yizhang2077 @merrymercy diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index c84a764150a..5493c4201c4 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -10,6 +10,7 @@ ## Checklist -- [ ] Format your code according to the [Contributor Guide](https://github.com/sgl-project/sglang/blob/main/docs/references/contributor_guide.md). -- [ ] Add unit tests as outlined in the [Contributor Guide](https://github.com/sgl-project/sglang/blob/main/docs/references/contributor_guide.md). 
-- [ ] Update documentation as needed, including docstrings or example tutorials. +- [ ] Format your code according to the [Code Formatting with Pre-Commit](https://docs.sglang.ai/references/contribution_guide.html#code-formatting-with-pre-commit). +- [ ] Add unit tests as outlined in the [Running Unit Tests](https://docs.sglang.ai/references/contribution_guide.html#running-unit-tests-adding-to-ci). +- [ ] Update documentation / docstrings / example tutorials as needed, according to [Writing Documentation](https://docs.sglang.ai/references/contribution_guide.html#writing-documentation-running-docs-ci). +- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to [Benchmark and Profiling](https://docs.sglang.ai/references/benchmark_and_profiling.html). diff --git a/.github/workflows/execute-notebook.yml b/.github/workflows/execute-notebook.yml index e03edd6ce79..49d649797ed 100644 --- a/.github/workflows/execute-notebook.yml +++ b/.github/workflows/execute-notebook.yml @@ -42,7 +42,7 @@ jobs: python -m ipykernel install --user --name python3 --display-name "Python 3" - name: Execute notebooks - timeout-minutes: 30 + timeout-minutes: 40 run: | cd docs make clean diff --git a/.github/workflows/experiment-runner.yml b/.github/workflows/experiment-runner.yml new file mode 100644 index 00000000000..5ccb8ad28ff --- /dev/null +++ b/.github/workflows/experiment-runner.yml @@ -0,0 +1,30 @@ +name: Experiment Runner + +on: + workflow_dispatch: + inputs: + script: + description: "Experiment Runner Script" + default: "configs/sharegpt_config.yaml" + +concurrency: + group: experiment-runner-${{ github.ref }} + cancel-in-progress: true + +jobs: + experiment-runner-1-gpu: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: 1-gpu-runner + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Install dependencies + run: | + bash scripts/ci_install_dependency.sh + + - name: Test experiment runner + timeout-minutes: 120 + run: | + cd test/srt + python3 experiment_runner.py --config ${{ inputs.script }} diff --git a/.github/workflows/nightly-eval.yml b/.github/workflows/nightly-test.yml similarity index 68% rename from .github/workflows/nightly-eval.yml rename to .github/workflows/nightly-test.yml index 7b77c63a54c..687fea7f733 100644 --- a/.github/workflows/nightly-eval.yml +++ b/.github/workflows/nightly-test.yml @@ -1,4 +1,4 @@ -name: Nightly Evaluation +name: Nightly Test on: schedule: @@ -11,11 +11,11 @@ on: workflow_dispatch: concurrency: - group: nightly-eval-${{ github.ref }} + group: nightly-test-${{ github.ref }} cancel-in-progress: true jobs: - nightly-eval-2-gpu: + nightly-test: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' runs-on: 2-gpu-runner steps: @@ -27,14 +27,8 @@ jobs: bash scripts/ci_install_dependency.sh pip install --upgrade "evalplus[vllm] @ git+https://github.com/evalplus/evalplus" - - name: Test gsm8k + - name: Run test timeout-minutes: 120 run: | cd test/srt - python3 test_nightly_gsm8k_eval.py - - - name: Test human eval - timeout-minutes: 120 - run: | - cd test/srt - python3 test_nightly_human_eval.py + python3 run_suite.py --suite nightly --timeout-per-file 2400 diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index 0df81b487b5..277ddef774e 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -4,11 +4,11 @@ on: push: branches: [ main ] paths: - - "rust/**" + - 
"sgl-router/**" pull_request: branches: [ main ] paths: - - "rust/**" + - "sgl-router/**" workflow_dispatch: concurrency: @@ -30,17 +30,17 @@ jobs: - name: Run fmt run: | source "$HOME/.cargo/env" - cd rust/ + cd sgl-router/ cargo fmt -- --check - name: Run test timeout-minutes: 20 run: | source "$HOME/.cargo/env" - cd rust/ + cd sgl-router/ cargo test - e2e-rust: + e2e-python: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' runs-on: 2-gpu-runner steps: @@ -54,17 +54,18 @@ jobs: - name: Build python binding run: | source "$HOME/.cargo/env" - cd rust + cd sgl-router pip install setuptools-rust wheel build python3 -m build - pip install dist/*.whl + pip install --force-reinstall dist/*.whl - name: Run e2e test run: | - cd rust/py_test + bash scripts/killall_sglang.sh "nuk_gpus" + cd sgl-router/py_test python3 run_suite.py finish: - needs: [unit-test-rust, e2e-rust] + needs: [unit-test-rust, e2e-python] runs-on: ubuntu-latest steps: - name: Finish diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml new file mode 100644 index 00000000000..df059c1f402 --- /dev/null +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -0,0 +1,103 @@ +name: PR Test (sgl-kernel) + +on: + push: + branches: [ main ] + paths: + - "sgl-kernel/**" + pull_request: + branches: [ main ] + paths: + - "sgl-kernel/**" + workflow_dispatch: + +concurrency: + group: pr-test-sgl-kernel-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Check clang-format + uses: DoozyX/clang-format-lint-action@v0.18.1 + with: + source: sgl-kernel + extensions: h,c,cpp,hpp,cu,cuh,cc + clangFormatVersion: 16 + style: file + + build-wheels: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: sgl-kernel-build-node + strategy: + matrix: + python-version: ['3.9'] + cuda-version: ['12.4'] + + steps: + - name: Cleanup + run: | + sudo rm -rf $GITHUB_WORKSPACE/* || true + + - uses: actions/checkout@v4 + with: + submodules: 'recursive' + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Build wheels for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }} + run: | + cd sgl-kernel + chmod +x ./build.sh + ./build.sh "${{ matrix.python-version }}" "${{ matrix.cuda-version }}" + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: wheel-python${{ matrix.python-version }}-cuda${{ matrix.cuda-version }} + path: sgl-kernel/dist/* + + unit-test: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + needs: build-wheels + runs-on: 1-gpu-runner + steps: + - uses: actions/checkout@v4 + + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-* + + - name: Install + run: | + pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.6.4.post1 + pip3 uninstall sgl-kernel -y || true + pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps + pip3 list | grep sgl-kernel + + - name: Run test + timeout-minutes: 30 + run: | + cd sgl-kernel + find tests -name "test_*.py" | xargs -n 1 python3 + + - name: Uninstall dependencies + run: | + pip3 uninstall sgl-kernel -y + + finish: + needs: [unit-test, lint] + runs-on: ubuntu-latest + steps: + - name: Finish + run: echo 
"This is an empty step to ensure that all jobs are completed." diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index 49c6ec88327..6ed6046ee6a 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -29,7 +29,7 @@ concurrency: jobs: unit-test-frontend: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner steps: - name: Checkout code @@ -37,7 +37,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu121/torch2.4/' || 'https://flashinfer.ai/whl/cu121/torch2.4/' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -45,35 +45,38 @@ jobs: timeout-minutes: 10 run: | cd test/lang - python3 run_suite.py --suite minimal + python3 run_suite.py --suite per-commit unit-test-backend-1-gpu: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner strategy: + fail-fast: false matrix: - range: [0-6, 6-15, 15-23, 23-30, 30-100] + range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-100] steps: - name: Checkout code uses: actions/checkout@v3 - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu121/torch2.4/' || 'https://flashinfer.ai/whl/cu121/torch2.4/' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} run: | bash scripts/ci_install_dependency.sh - name: Run test timeout-minutes: 25 run: | - cd test/srt RANGE=${{ matrix.range }} range_begin=${RANGE%-*} range_end=${RANGE#*-} - python3 run_suite.py --suite minimal --range-begin ${range_begin} --range-end ${range_end} + + cd test/srt + python3 run_suite.py --suite per-commit --range-begin ${range_begin} --range-end ${range_end} + unit-test-backend-2-gpu: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 2-gpu-runner steps: - name: Checkout code @@ -81,22 +84,20 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu121/torch2.4/' || 'https://flashinfer.ai/whl/cu121/torch2.4/' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} run: | bash scripts/ci_install_dependency.sh - - name: Evaluate data parallelism accuracy (DP=2) + - name: Test data parallelism (DP=2) timeout-minutes: 10 run: | cd test/srt python3 test_data_parallelism.py - - name: Evaluate MLA accuracy (TP=2) + - name: Test data parallelism attention (DP=2) timeout-minutes: 10 run: | cd test/srt - python3 test_mla.py - python3 test_mla_fp8.py python3 test_dp_attention.py - name: Test update weights from distributed @@ -105,14 +106,14 @@ jobs: cd test/srt python3 
test_update_weights_from_distributed.py - - name: Evaluate MoE EP accuracy (TP=2) + - name: Test expert parallelism (EP=2) timeout-minutes: 10 run: | cd test/srt python3 test_moe_ep.py performance-test-1-gpu-part-1: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner steps: - name: Checkout code @@ -120,7 +121,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu121/torch2.4/' || 'https://flashinfer.ai/whl/cu121/torch2.4/' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -128,7 +129,7 @@ jobs: timeout-minutes: 10 run: | cd test/srt - python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_default + python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_bs1 - name: Benchmark online latency timeout-minutes: 10 @@ -148,8 +149,15 @@ jobs: cd test/srt python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size + - name: Benchmark online latency (EAGLE) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_bench_serving.TestBenchServing.test_online_latency_eagle + + performance-test-1-gpu-part-2: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner steps: - name: Checkout code @@ -157,7 +165,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu121/torch2.4/' || 'https://flashinfer.ai/whl/cu121/torch2.4/' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -180,7 +188,7 @@ jobs: python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default_fp8 performance-test-2-gpu: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 2-gpu-runner steps: - name: Checkout code @@ -188,7 +196,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu121/torch2.4/' || 'https://flashinfer.ai/whl/cu121/torch2.4/' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -196,7 +204,13 @@ jobs: timeout-minutes: 10 run: | cd test/srt - python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default + python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1 + + - name: Benchmark single latency + torch.compile (TP=2) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_torch_compile_tp2_bs1 - name: Benchmark offline throughput 
(TP=2) timeout-minutes: 10 @@ -210,8 +224,9 @@ jobs: cd test/srt python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache + accuracy-test-1-gpu: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner steps: - name: Checkout code @@ -219,7 +234,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu121/torch2.4/' || 'https://flashinfer.ai/whl/cu121/torch2.4/' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} run: | bash scripts/ci_install_dependency.sh @@ -235,7 +250,7 @@ jobs: accuracy-test-2-gpu: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 2-gpu-runner steps: - name: Checkout code @@ -243,7 +258,7 @@ jobs: - name: Install dependencies env: - FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu121/torch2.4/' || 'https://flashinfer.ai/whl/cu121/torch2.4/' }} + FLASHINFER_REPO: ${{ inputs.version == 'nightly' && 'https://flashinfer.ai/whl/nightly/cu124/torch2.4/flashinfer' || 'https://flashinfer.ai/whl/cu124/torch2.4/flashinfer' }} run: | bash scripts/ci_install_dependency.sh diff --git a/.github/workflows/release-docker-amd.yml b/.github/workflows/release-docker-amd.yml index 866cc5fa520..228eecdb9c5 100644 --- a/.github/workflows/release-docker-amd.yml +++ b/.github/workflows/release-docker-amd.yml @@ -10,19 +10,27 @@ on: jobs: publish: if: github.repository == 'sgl-project/sglang' - runs-on: docker-builder-amd + runs-on: amd-docker environment: 'prod' strategy: matrix: rocm_version: ['6.2.0'] build_type: ['all', 'srt'] steps: - - name: Delete huge unnecessary tools folder - run: rm -rf /opt/hostedtoolcache - - name: Checkout repository uses: actions/checkout@v3 + - name: Free disk space + uses: jlumbroso/free-disk-space@main + with: + tool-cache: false + docker-images: false + android: true + dotnet: true + haskell: true + large-packages: true + swap-storage: false + - name: Login to Docker Hub uses: docker/login-action@v2 with: diff --git a/.github/workflows/release-docker-dev.yml b/.github/workflows/release-docker-dev.yml new file mode 100644 index 00000000000..1526f802e53 --- /dev/null +++ b/.github/workflows/release-docker-dev.yml @@ -0,0 +1,35 @@ +name: Build Development Docker Image + +on: + workflow_dispatch: + schedule: + - cron: '0 0 * * *' + +jobs: + build-dev: + runs-on: ubuntu-22.04 + steps: + - name: Checkout repository + uses: actions/checkout@v3 + + - name: Free disk space + uses: jlumbroso/free-disk-space@main + with: + tool-cache: false + docker-images: false + android: true + dotnet: true + haskell: true + large-packages: true + swap-storage: false + + - name: Login to Docker Hub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Build and Push Dev Image + run: | + docker build . 
-f docker/Dockerfile.dev -t lmsysorg/sglang:dev --no-cache + docker push lmsysorg/sglang:dev diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml index 619b07fe341..d5669886d18 100644 --- a/.github/workflows/release-docker.yml +++ b/.github/workflows/release-docker.yml @@ -14,7 +14,7 @@ jobs: environment: 'prod' strategy: matrix: - cuda_version: ['11.8.0', '12.1.1', '12.4.1'] + cuda_version: ['11.8.0', '12.1.1', '12.4.1', '12.5.1'] build_type: ['all', 'srt'] steps: - name: Delete huge unnecessary tools folder @@ -39,6 +39,8 @@ jobs: cuda_tag="cu121" elif [ "${{ matrix.cuda_version }}" = "12.4.1" ]; then cuda_tag="cu124" + elif [ "${{ matrix.cuda_version }}" = "12.5.1" ]; then + cuda_tag="cu125" else echo "Unsupported CUDA version" exit 1 @@ -58,7 +60,7 @@ jobs: docker build . -f docker/Dockerfile --build-arg CUDA_VERSION=${{ matrix.cuda_version }} --build-arg BUILD_TYPE=${{ matrix.build_type }} -t lmsysorg/sglang:${tag}${tag_suffix} --no-cache docker push lmsysorg/sglang:${tag}${tag_suffix} - if [ "${{ matrix.cuda_version }}" = "12.1.1" ]; then + if [ "${{ matrix.cuda_version }}" = "12.5.1" ]; then docker tag lmsysorg/sglang:${tag}${tag_suffix} lmsysorg/sglang:latest${tag_suffix} docker push lmsysorg/sglang:latest${tag_suffix} fi diff --git a/.github/workflows/release-docs.yml b/.github/workflows/release-docs.yml index ab2129e3721..37db70c7c4b 100644 --- a/.github/workflows/release-docs.yml +++ b/.github/workflows/release-docs.yml @@ -39,7 +39,7 @@ jobs: - name: Execute notebooks and push to documents env: - GITHUB_TOKEN: ${{ secrets.PAT_TOKEN }} + GITHUB_TOKEN: ${{ secrets.DOCUMENTATION_PAT_TOKEN }} run: | cd docs make clean @@ -49,7 +49,7 @@ jobs: cd _build/html git clone https://$GITHUB_TOKEN@github.com/sgl-project/sgl-project.github.io.git ../sgl-project.github.io --depth 1 - rm -rf ../sgl-project.github.io/* + find ../sgl-project.github.io/ -mindepth 1 -not -path "../sgl-project.github.io/.git*" -not -name CNAME -not -name ".jekyll" -not -name ".nojekyll" -delete cp -r * ../sgl-project.github.io cp ../../README.md ../sgl-project.github.io/README.md cd ../sgl-project.github.io diff --git a/.github/workflows/release-pypi-kernel.yml b/.github/workflows/release-pypi-kernel.yml new file mode 100644 index 00000000000..af34c8423ce --- /dev/null +++ b/.github/workflows/release-pypi-kernel.yml @@ -0,0 +1,44 @@ +name: Release SGLang Kernel to PyPI + +on: + push: + branches: + - main + paths: + - sgl-kernel/version.py + workflow_dispatch: + +concurrency: + group: release-pypi-kernel-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-wheels: + if: github.repository == 'sgl-project/sglang' + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9'] + cuda-version: ['12.4'] + + steps: + - uses: actions/checkout@v4 + with: + submodules: 'recursive' + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Build wheels for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }} + run: | + cd sgl-kernel + chmod +x ./build.sh + ./build.sh "${{ matrix.python-version }}" "${{ matrix.cuda-version }}" + + - name: Upload to pypi + working-directory: sgl-kernel + run: | + pip install twine + python3 -m twine upload dist/* -u __token__ -p ${{ secrets.PYPI_TOKEN }} diff --git a/.github/workflows/release-pypi-router.yml b/.github/workflows/release-pypi-router.yml index fbbb1f0243b..547522e8aa6 100644 --- 
a/.github/workflows/release-pypi-router.yml +++ b/.github/workflows/release-pypi-router.yml @@ -7,7 +7,7 @@ on: branches: - main paths: - - rust/pyproject.toml + - sgl-router/pyproject.toml workflow_dispatch: jobs: @@ -26,9 +26,9 @@ jobs: with: path: sglang-repo - - name: Move rust folder to root and delete sglang-repo + - name: Move sgl-router folder to root and delete sglang-repo run: | - mv sglang-repo/rust/* . + mv sglang-repo/sgl-router/* . rm -rf sglang-repo ls -alt @@ -69,9 +69,9 @@ jobs: with: path: sglang-repo - - name: Move rust folder to root, copy the license file, and delete sglang-repo + - name: Move sgl-router folder to root, copy the license file, and delete sglang-repo run: | - mv sglang-repo/rust/* . + mv sglang-repo/sgl-router/* . mv sglang-repo/LICENSE . rm -rf sglang-repo ls -alt @@ -84,6 +84,7 @@ jobs: - name: Build SDist run: | pip install build + python -m pip install -U packaging python -m build --sdist - uses: actions/upload-artifact@v4 diff --git a/.github/workflows/release-whl-kernel.yml b/.github/workflows/release-whl-kernel.yml new file mode 100644 index 00000000000..70c451778fa --- /dev/null +++ b/.github/workflows/release-whl-kernel.yml @@ -0,0 +1,92 @@ +name: Release SGLang Kernel Wheel (cu118) + +on: + workflow_dispatch: + inputs: + tag_name: + type: string + push: + branches: + - main + paths: + - sgl-kernel/version.py + +jobs: + build-wheels: + if: github.repository == 'sgl-project/sglang' + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9'] + cuda-version: ['11.8'] + + steps: + - uses: actions/checkout@v4 + with: + submodules: 'recursive' + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Build wheels for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }} + run: | + cd sgl-kernel + chmod +x ./build.sh + ./build.sh "${{ matrix.python-version }}" "${{ matrix.cuda-version }}" + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: wheel-python${{ matrix.python-version }}-cuda${{ matrix.cuda-version }} + path: sgl-kernel/dist/* + + release: + needs: build-wheels + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-* + + - name: Set tag name + id: set_tag_name + run: | + if [ -z "${{ inputs.tag_name }}" ]; then + TAG_NAME="v$(cat sgl-kernel/version.py | cut -d'"' -f2)" + echo "tag_name=$TAG_NAME" >> $GITHUB_OUTPUT + else + echo "tag_name=${{ inputs.tag_name }}" >> $GITHUB_OUTPUT + fi + + - name: Release + uses: softprops/action-gh-release@v2 + with: + tag_name: ${{ steps.set_tag_name.outputs.tag_name }} + repository: sgl-project/whl + token: ${{ secrets.WHL_TOKEN }} + files: | + sgl-kernel/dist/* + + - name: Clone wheel index + run: git clone https://oauth2:${WHL_TOKEN}@github.com/sgl-project/whl.git sgl-whl + env: + WHL_TOKEN: ${{ secrets.WHL_TOKEN }} + + - name: Update wheel index + run: python3 scripts/update_kernel_whl_index.py + + - name: Push wheel index + run: | + cd sgl-whl + git config --local user.name "github-actions[bot]" + git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" + git add -A + git commit -m "update whl index" + git push diff --git a/.gitignore b/.gitignore index 6d0987f2782..75e29fac373 100644 --- a/.gitignore +++ b/.gitignore @@ -220,3 +220,10 @@ work_dirs/ *.app compile_commands.json 
+ +*.iml + +# VSCode +.vscode + +1 diff --git a/.gitmodules b/.gitmodules index e69de29bb2d..97f3421449d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -0,0 +1,12 @@ +[submodule "sgl-kernel/3rdparty/cutlass"] + path = sgl-kernel/3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass.git +[submodule "sgl-kernel/3rdparty/cccl"] + path = sgl-kernel/3rdparty/cccl + url = https://github.com/NVIDIA/cccl.git +[submodule "sgl-kernel/3rdparty/flashinfer"] + path = sgl-kernel/3rdparty/flashinfer + url = https://github.com/flashinfer-ai/flashinfer.git +[submodule "sgl-kernel/3rdparty/turbomind"] + path = sgl-kernel/3rdparty/turbomind + url = https://github.com/InternLM/turbomind diff --git a/3rdparty/amd/profiling/PROFILING.md b/3rdparty/amd/profiling/PROFILING.md index 79bc75b503b..7e15ec844f2 100644 --- a/3rdparty/amd/profiling/PROFILING.md +++ b/3rdparty/amd/profiling/PROFILING.md @@ -336,7 +336,7 @@ loadTracer.sh python3 -m sglang.launch_server \ --model-path /sgl-workspace/sglang/dummy_grok1 \ --tokenizer-path Xenova/grok-1-tokenizer \ --load-format dummy \ - --quant fp8 \ + --quantization fp8 \ --tp 8 \ --port 30000 \ --disable-radix-cache 2>&1 | tee "$LOGFILE" diff --git a/3rdparty/amd/profiling/server.sh b/3rdparty/amd/profiling/server.sh index aa574f64c94..f877e6c7acd 100755 --- a/3rdparty/amd/profiling/server.sh +++ b/3rdparty/amd/profiling/server.sh @@ -14,7 +14,7 @@ loadTracer.sh python3 -m sglang.launch_server \ --model-path /sgl-workspace/sglang/dummy_grok1 \ --tokenizer-path Xenova/grok-1-tokenizer \ --load-format dummy \ - --quant fp8 \ + --quantization fp8 \ --tp 8 \ --port 30000 \ --disable-radix-cache 2>&1 | tee "$LOGFILE" diff --git a/3rdparty/amd/tuning/TUNING.md b/3rdparty/amd/tuning/TUNING.md index a38a16d4f7a..0638041c974 100644 --- a/3rdparty/amd/tuning/TUNING.md +++ b/3rdparty/amd/tuning/TUNING.md @@ -104,7 +104,7 @@ To maximize moe kernel efficiency, need to use below scripts to find out the bes ```bash #Tuning -#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quant fp" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run). +#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8" to run, which defines batch-size 32, input length 1024, and output length 8; from the "--batch" MoE viewpoint, the prefill batch is 32*1024 = 32768 and the decode batch is 32*1 (only one output token is generated in each run).
#so we can tune decode moe use below command python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32" # and use this command to tune prefill moe diff --git a/3rdparty/amd/tuning/benchmark_moe_rocm.py b/3rdparty/amd/tuning/benchmark_moe_rocm.py index a3f26e8e502..5aff8c0d664 100644 --- a/3rdparty/amd/tuning/benchmark_moe_rocm.py +++ b/3rdparty/amd/tuning/benchmark_moe_rocm.py @@ -10,7 +10,10 @@ from tqdm import tqdm from transformers import AutoConfig -from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe, get_config_file_name +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + fused_moe, + get_config_file_name, +) padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 diff --git a/Makefile b/Makefile index 0cf0dcd1a2e..23d1ddf72ee 100644 --- a/Makefile +++ b/Makefile @@ -1,13 +1,18 @@ -.PHONY: check-deps install-deps format update +.PHONY: check-deps install-deps format update help -check-deps: +# Show help for each target +help: + @echo "Available targets:" + @grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' + +check-deps: ## Check and install required Python formatting dependencies @command -v isort >/dev/null 2>&1 || (echo "Installing isort..." && pip install isort) @command -v black >/dev/null 2>&1 || (echo "Installing black..." && pip install black) -install-deps: +install-deps: ## Install Python formatting tools (isort and black) pip install isort black -format: check-deps +format: check-deps ## Format modified Python files using isort and black @echo "Formatting modified Python files..." git diff --name-only --diff-filter=M | grep '\.py$$' | xargs -I {} sh -c 'isort {} && black {}' @@ -17,7 +22,7 @@ FILES_TO_UPDATE = docker/Dockerfile.rocm \ docs/developer/setup_github_runner.md \ docs/start/install.md -update: +update: ## Update version numbers across project files. Usage: make update @if [ -z "$(filter-out $@,$(MAKECMDGOALS))" ]; then \ echo "Version required. Usage: make update "; \ exit 1; \ diff --git a/README.md b/README.md index bc8734936cd..b27271a1810 100644 --- a/README.md +++ b/README.md @@ -12,20 +12,23 @@ -------------------------------------------------------------------------------- -| [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/) | [**Documentation**](https://sgl-project.github.io/) | [**Join Slack**](https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2tmmp6flg-89dOlJW2TjnBrTRk1I_~GA) | -[**Join Bi-Weekly Development Meeting**](https://docs.google.com/document/d/1xEow4eIM152xNcRxqZz9VEcOiTQo8-CEuuQ5qTmkt-E/edit?usp=sharing) | [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) | +| [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/) +| [**Documentation**](https://docs.sglang.ai/) +| [**Join Slack**](https://slack.sglang.ai/) +| [**Join Bi-Weekly Development Meeting**](https://meeting.sglang.ai/) +| [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) | ## News -- [2024/12] 🔥 SGLang v0.4: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)). -- [2024/10] 🔥 The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)). 
-- [2024/09] SGLang v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)). -- [2024/07] Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)). +- [2025/01] 🔥 SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html)) +- [2024/12] 🔥 v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)). +- [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)). +- [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)).
More +- [2024/10] The First SGLang Online Meetup ([slides](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#the-first-sglang-online-meetup)). - [2024/02] SGLang enables **3x faster JSON decoding** with compressed finite state machine ([blog](https://lmsys.org/blog/2024-02-05-compressed-fsm/)). -- [2024/04] SGLang is used by the official **LLaVA-NeXT (video)** release ([blog](https://llava-vl.github.io/blog/2024-04-30-llava-next-video/)). - [2024/01] SGLang provides up to **5x faster inference** with RadixAttention ([blog](https://lmsys.org/blog/2024-01-17-sglang/)). - [2024/01] SGLang powers the serving of the official **LLaVA v1.6** release demo ([usage](https://github.com/haotian-liu/LLaVA?tab=readme-ov-file#demo)). @@ -42,20 +45,20 @@ The core features include: - **Active Community**: SGLang is open-source and backed by an active community with industry adoption. ## Getting Started -- [Install SGLang](https://sgl-project.github.io/start/install.html) -- [Send requests](https://sgl-project.github.io/start/send_request.html) -- [Backend: SGLang Runtime (SRT)](https://sgl-project.github.io/backend/backend.html) -- [Frontend: Structured Generation Language (SGLang)](https://sgl-project.github.io/frontend/frontend.html) +- [Install SGLang](https://docs.sglang.ai/start/install.html) +- [Quick Start](https://docs.sglang.ai/start/send_request.html) +- [Backend Tutorial](https://docs.sglang.ai/backend/openai_api_completions.html) +- [Frontend Tutorial](https://docs.sglang.ai/frontend/frontend.html) +- [Contribution Guide](https://docs.sglang.ai/references/contribution_guide.html) -## Benchmark And Performance -Learn more in our release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/) +## Benchmark and Performance +Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/) ## Roadmap [Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487) ## Adoption and Sponsorship -The project is supported by (alphabetically): AMD, Baseten, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, Meituan, NVIDIA, RunPod, Stanford, UC Berkeley, xAI and 01.AI. +The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS.org, Meituan, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI. ## Acknowledgment and Citation -We learned from the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). -Please cite our paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful. 
+We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). Please cite the paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful. diff --git a/benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py b/benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py new file mode 100644 index 00000000000..86648e5ff17 --- /dev/null +++ b/benchmark/bench_in_batch_prefix/bench_in_batch_prefix.py @@ -0,0 +1,130 @@ +# Benchmark with lots of common prefixes. Used to benchmark prefix caching performance. +# +# Launch a server: +# python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --log-level-http warning + +import random +import string +import time + +from tqdm import tqdm +from transformers import AutoTokenizer + +import sglang as sgl +from sglang import set_default_backend +from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint + + +def generate_random_string(token_length: int) -> str: + random_string = "".join( + random.choices(string.ascii_letters + string.digits, k=token_length * 100) + ) + tokenized_output = tokenizer.encode(random_string, add_special_tokens=False)[ + :token_length + ] + + if len(tokenized_output) < token_length: + tokenized_output = tokenized_output + [tokenizer.pad_token_id] * ( + token_length - len(tokenized_output) + ) + + decoded_string = tokenizer.decode(tokenized_output, skip_special_tokens=False) + return decoded_string + + +def generate_unique_prefix(base_text, index): + return str(index) + base_text[len(str(index)) :] + + +@sgl.function +def text_qa(s, question, gen_len): + s += "Q: " + question + "\n" + s += "A:" + sgl.gen("answer", stop="\n", temperature=0, max_tokens=gen_len) + + +def prepare_prompts(num_prefix, num_samples_per_prefix, prefix_length, suffix_length): + base_prefix = generate_random_string(prefix_length) + + tot_input_len = 0 + all_prompts = [] + for i in tqdm(range(num_prefix), desc="prepare prompts"): + unique_prefix = generate_unique_prefix(base_prefix, i) + prompt_list = [] + for j in range(num_samples_per_prefix): + suffix = generate_random_string(suffix_length) + prompt = unique_prefix + suffix + prompt_list.append(prompt) + tot_input_len += len(tokenizer.encode(prompt)) + all_prompts.append(prompt_list) + return all_prompts, tot_input_len + + +def test_batch_by_batch(all_prompts, gen_len): + backend.flush_cache() + + tot_time = 0 + for i in range(len(all_prompts)): + tic = time.time() + text_qa.run_batch( + list(zip(all_prompts[i], [gen_len] * len(all_prompts[i]))), + ) + tot_time += time.time() - tic + + return tot_time + + +def test_batch_by_batch_with_hint(all_prompts, gen_len): + backend.flush_cache() + + tot_time = 0 + for i in range(len(all_prompts)): + tic = time.time() + # Send a hint to cache the prefix + text_qa.run_batch(list(zip(all_prompts[i][:1], [gen_len]))) + # Send the batch + text_qa.run_batch(list(zip(all_prompts[i], [gen_len] * len(all_prompts[i])))) + + tot_time += time.time() - tic + + return tot_time + + +def test_send_all(all_prompts, gen_len): + backend.flush_cache() + + all_prompts = [x for prompt_list in all_prompts for x in prompt_list] + + tic = time.time() + text_qa.run_batch( + 
list(zip(all_prompts, [gen_len] * len(all_prompts))), + ) + tot_time = time.time() - tic + + return tot_time + + +if __name__ == "__main__": + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + backend = RuntimeEndpoint("http://127.0.0.1:30000") + set_default_backend(backend) + + random.seed(0) + num_prefix = 10 + num_samples_per_prefix = 32 + prefix_length = 1024 + suffix_length = 128 + gen_len = 1 + all_prompts, tot_input_len = prepare_prompts( + num_prefix, num_samples_per_prefix, prefix_length, suffix_length + ) + + print(f"Total input token length: {tot_input_len}\n") + + cost = test_batch_by_batch(all_prompts, gen_len) + print(f"Latency of test_batch_by_batch : {cost:.4f} s\n") + + cost = test_batch_by_batch_with_hint(all_prompts, gen_len) + print(f"Latency of test_batch_by_batch_with_hint: {cost:.4f} s\n") + + cost = test_send_all(all_prompts, gen_len) + print(f"Latency of test_send_all : {cost:.4f} s\n") diff --git a/benchmark/blog_v0_2/405b_sglang.sh b/benchmark/blog_v0_2/405b_sglang.sh index 4e3372ae8c7..49185378280 100644 --- a/benchmark/blog_v0_2/405b_sglang.sh +++ b/benchmark/blog_v0_2/405b_sglang.sh @@ -6,7 +6,7 @@ # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json # Launch sglang -# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87 +# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quantization fp8 --disable-radix --mem-frac 0.87 # offline python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11 diff --git a/benchmark/deepseek_v3/README.md b/benchmark/deepseek_v3/README.md new file mode 100644 index 00000000000..ea972831a36 --- /dev/null +++ b/benchmark/deepseek_v3/README.md @@ -0,0 +1,128 @@ +# DeepSeek V3 Support + +The SGLang and DeepSeek teams collaborated to get DeepSeek V3 FP8 running on NVIDIA and AMD GPUs **from day one**. SGLang also supports [MLA optimization](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [DP attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models), making SGLang one of the best open-source LLM engines for running DeepSeek models. SGLang is the inference engine recommended by the official [DeepSeek team](https://github.com/deepseek-ai/DeepSeek-V3/tree/main?tab=readme-ov-file#62-inference-with-sglang-recommended). + +Special thanks to Meituan's Search & Recommend Platform Team and Baseten's Model Performance Team for implementing the model, and DataCrunch for providing GPU resources. + +For optimizations made on the DeepSeek series models regarding SGLang, please refer to [DeepSeek Model Optimizations in SGLang](https://docs.sglang.ai/references/deepseek.html). + +## Hardware Recommendation +- 8 x NVIDIA H200 GPUs + +If you do not have GPUs with large enough memory, please try multi-node tensor parallelism. There is an example serving with [2 H20 nodes](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-2-h208) below. + +## Installation & Launch + +If you encounter errors when starting the server, ensure the weights have finished downloading. It's recommended to download them beforehand or restart multiple times until all weights are downloaded. 
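One way to pre-fetch the weights into the local Hugging Face cache before launching (a minimal sketch, assuming `huggingface_hub` is available in your environment; it uses the same repo id as the launch commands below):

```python3
# Download (or resume downloading) the DeepSeek V3 weights ahead of time,
# so the server does not stall on a partial download at startup.
# Safe to re-run: files already in the cache are reused.
from huggingface_hub import snapshot_download

snapshot_download(repo_id="deepseek-ai/DeepSeek-V3")
```

Since the Docker commands below mount `~/.cache/huggingface` into the container, weights cached this way are picked up without being downloaded again.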
+ +### Using Docker (Recommended) +```bash +# Pull latest image +# https://hub.docker.com/r/lmsysorg/sglang/tags +docker pull lmsysorg/sglang:latest + +# Launch +docker run --gpus all --shm-size 32g -p 30000:30000 -v ~/.cache/huggingface:/root/.cache/huggingface --ipc=host lmsysorg/sglang:latest \ + python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code --port 30000 +``` + +For high QPS scenarios, add the `--enable-dp-attention` argument to boost throughput. + +### Using pip +```bash +# Installation +pip install "sglang[all]>=0.4.1.post5" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer + +# Launch +python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code +``` + +For high QPS scenarios, add the `--enable-dp-attention` argument to boost throughput. + +### Example: Sending requests with OpenAI API + +```python3 +import openai +client = openai.Client( + base_url="http://127.0.0.1:30000/v1", api_key="EMPTY") + +# Chat completion +response = client.chat.completions.create( + model="default", + messages=[ + {"role": "system", "content": "You are a helpful AI assistant"}, + {"role": "user", "content": "List 3 countries and their capitals."}, + ], + temperature=0, + max_tokens=64, +) +print(response) +``` + +### Example: Serving with two H20*8 nodes +For example, there are two H20 nodes, each with 8 GPUs. The first node's IP is `10.0.0.1`, and the second node's IP is `10.0.0.2`. Please **use the first node's IP** for both commands. + +If the command fails, try setting the `GLOO_SOCKET_IFNAME` parameter. For more information, see [Common Environment Variables](https://pytorch.org/docs/stable/distributed.html#common-environment-variables). + +```bash +# node 1 +python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 0 --trust-remote-code + +# node 2 +python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 10.0.0.1:5000 --nnodes 2 --node-rank 1 --trust-remote-code +``` + +If you have two H100 nodes, the usage is similar to the aforementioned H20. + +### Example: Serving with two H200*8 nodes and docker +There are two H200 nodes, each with 8 GPUs. The first node's IP is `192.168.114.10`, and the second node's IP is `192.168.114.11`. Configure the endpoint to expose it to another Docker container using `--host 0.0.0.0` and `--port 40000`, and set up communications with `--dist-init-addr 192.168.114.10:20000`. +A single H200 with 8 devices can run DeepSeek V3, the dual H200 setup is just to demonstrate multi-node usage. 
+ +```bash +# node 1 +docker run --gpus all \ + --shm-size 32g \ + --network=host \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + --name sglang_multinode1 \ + -it \ + --rm \ + --env "HF_TOKEN=$HF_TOKEN" \ + --ipc=host \ + lmsysorg/sglang:latest \ + python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 192.168.114.10:20000 --nnodes 2 --node-rank 0 --trust-remote-code --host 0.0.0.0 --port 40000 +``` + +```bash +# node 2 +docker run --gpus all \ + --shm-size 32g \ + --network=host \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + --name sglang_multinode2 \ + -it \ + --rm \ + --env "HF_TOKEN=$HF_TOKEN" \ + --ipc=host \ + lmsysorg/sglang:latest \ + python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --dist-init-addr 192.168.114.10:20000 --nnodes 2 --node-rank 1 --trust-remote-code --host 0.0.0.0 --port 40000 +``` + +To ensure functionality, we include a test from a client Docker container. +```bash +docker run --gpus all \ + --shm-size 32g \ + --network=host \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + --name sglang_multinode_client \ + -it \ + --rm \ + --env "HF_TOKEN=$HF_TOKEN" \ + --ipc=host \ + lmsysorg/sglang:latest \ + python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 1 --random-output 512 --random-range-ratio 1 --num-prompts 1 --host 0.0.0.0 --port 40000 --output-file "deepseekv3_multinode.jsonl" +``` + +## DeepSeek V3 Optimization Plan + +https://github.com/sgl-project/sglang/issues/2591 diff --git a/benchmark/gsm8k/bench_sglang.py b/benchmark/gsm8k/bench_sglang.py index 9fe9b79baaf..f01734f0afb 100644 --- a/benchmark/gsm8k/bench_sglang.py +++ b/benchmark/gsm8k/bench_sglang.py @@ -1,6 +1,7 @@ import argparse import ast import json +import os import re import time @@ -46,9 +47,11 @@ def main(args): set_default_backend(select_sglang_backend(args)) # Read data + data_path = args.data_path url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" - filename = download_and_cache_file(url) - lines = list(read_jsonl(filename)) + if not os.path.isfile(data_path): + data_path = download_and_cache_file(url) + lines = list(read_jsonl(data_path)) # Construct prompts num_questions = args.num_questions diff --git a/benchmark/hellaswag/bench_sglang.py b/benchmark/hellaswag/bench_sglang.py index f09d7256da9..798521f9766 100644 --- a/benchmark/hellaswag/bench_sglang.py +++ b/benchmark/hellaswag/bench_sglang.py @@ -1,5 +1,6 @@ import argparse import json +import os import time import numpy as np @@ -31,9 +32,11 @@ def main(args): set_default_backend(select_sglang_backend(args)) # Read data + data_path = args.data_path url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl" - filename = download_and_cache_file(url) - lines = list(read_jsonl(filename)) + if not os.path.isfile(data_path): + data_path = download_and_cache_file(url) + lines = list(read_jsonl(data_path)) # Construct prompts num_questions = args.num_questions diff --git a/benchmark/hicache/bench_multiturn.py b/benchmark/hicache/bench_multiturn.py new file mode 100644 index 00000000000..ab34c33da44 --- /dev/null +++ b/benchmark/hicache/bench_multiturn.py @@ -0,0 +1,334 @@ +import argparse +import asyncio +import json +import queue +import random +import threading +import time +from typing import Optional + +import aiohttp +import requests +from tqdm.asyncio import tqdm + +from sglang.bench_serving import ( + RequestFuncOutput, + 
get_tokenizer, + remove_prefix, + sample_random_requests, +) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Script to benchmark concurrent requests to a server." + ) + parser.add_argument( + "--num-clients", + type=int, + default=200, + help="Number of concurrent clients", + ) + parser.add_argument( + "--request-length", + type=int, + default=512, + help="Length of each new request", + ) + parser.add_argument( + "--output-length", + type=int, + default=64, + help="Length of each output", + ) + parser.add_argument( + "--num-rounds", + type=int, + default=5, + help="Number of rounds per client", + ) + parser.add_argument( + "--distribution", + type=str, + default="poisson", + choices=["poisson", "uniform"], + help="Distribution type for request intervals (poisson or uniform)", + ) + parser.add_argument( + "--request-rate", + type=float, + default=1.0, + help="Average number of requests per second", + ) + parser.add_argument( + "--host", + type=str, + default="localhost", + help="Server hostname or IP (default: localhost)", + ) + parser.add_argument( + "--port", + type=int, + default=30000, + help="Server port (default: 30000)", + ) + parser.add_argument( + "--model", + type=str, + default="meta-llama/Llama-3.1-8B-Instruct", + help="model path compatible with Hugging Face Transformers", + ) + return parser.parse_args() + + +async def async_request_sglang_generate( + payload, + url, + pbar: Optional[tqdm] = None, +): + """ + Sends a streaming request to the server. Gathers text token-by-token. + """ + async with aiohttp.ClientSession() as session: + headers = {} + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + output = RequestFuncOutput() + + try: + async with session.post(url=url, json=payload, headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + latency = time.perf_counter() - st + if chunk == "[DONE]": + pass + else: + data = json.loads(chunk) + + if data["text"]: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text = data["text"] + + output.generated_text = generated_text + output.success = True + output.latency = latency + else: + output.error = response.reason or "" + output.success = False + except Exception as e: + output.success = False + output.error = str(e) + print(f"Request failed: {e}") + + if pbar: + pbar.update(1) + return output + + +def gen_payload(prompt, output_len): + payload = { + "text": prompt, + "sampling_params": { + "temperature": 0.0, + "max_new_tokens": output_len, + "ignore_eos": True, + }, + "stream": True, + "lora_path": "", + "return_logprob": False, + "logprob_start_len": -1, + } + return payload + + +class ReadyQueue: + """ + Thread-safe queue that can pop requests in different orders based on given policy. 
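+
+    "random" pops a uniformly random pending request; "fifo" pops the oldest one.
+    Any other policy raises ValueError (varying client think time is left as a TODO).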
+ """ + + def __init__(self, init_requests=None, policy="random"): + self.lock = threading.Lock() + self.requests = init_requests or [] + self.policy = policy + + def append(self, item): + with self.lock: + self.requests.append(item) + + def pop(self): + with self.lock: + if not self.requests: + return None + if self.policy == "random": + index = random.randrange(len(self.requests)) + return self.requests.pop(index) + elif self.policy == "fifo": + return self.requests.pop(0) + else: + # todo, varying thinking time of clients + raise ValueError(f"{self.policy} not implemented") + + +class WorkloadGenerator: + def __init__(self, args): + # Construct the base URL for requests + self.url = f"http://{args.host}:{args.port}/generate" + + self.tokenizer = get_tokenizer(args.model) + self.distribution = args.distribution + self.request_rate = args.request_rate + self.start_time = None + self.finished_time = None + + self.candidate_inputs = sample_random_requests( + input_len=args.request_length, + output_len=args.output_length, + num_prompts=args.num_clients * args.num_rounds, + range_ratio=1.0, + tokenizer=self.tokenizer, + dataset_path="", + ) + self.candidate_inputs = [i[0] for i in self.candidate_inputs] + + init_requests = [ + (i, gen_payload(self.candidate_inputs[i], args.output_length)) + for i in range(args.num_clients) + ] + self.client_records = { + i: {"round": 0, "history": init_requests[i][1]["text"]} + for i in range(args.num_clients) + } + self.ready_queue = ReadyQueue(init_requests=init_requests) + self.candidate_inputs = self.candidate_inputs[args.num_clients :] + + self.response_queue = queue.Queue() + self.pbar = tqdm(total=args.num_clients * args.num_rounds) + self.performance_metrics = {"ttft": [], "latency": []} + + async def handle_request(self, item): + try: + client_id, payload = item + response = await async_request_sglang_generate(payload, self.url, self.pbar) + if self.pbar.n == self.pbar.total: + self.finished_time = time.time() + self.response_queue.put((client_id, response)) + except Exception as e: + print(f"Request failed: {e}") + + def request_sender(self): + async def request_loop(): + while True: + # Calculate Poisson-distributed wait time + if self.distribution == "poisson": + sleep_time = random.expovariate(self.request_rate) + elif self.distribution == "uniform": + avg_interval = ( + 1.0 / self.request_rate if self.request_rate > 0 else 1.0 + ) + sleep_time = random.uniform(0, 2 * avg_interval) + else: + raise ValueError("Invalid distribution type") + await asyncio.sleep(sleep_time) # Wait before sending the next request + + new_request = self.ready_queue.pop() + # Submit async request + if new_request: + asyncio.create_task(self.handle_request(new_request)) + else: + if self.pbar.n == self.pbar.total: + break + + # Create and run the event loop for asynchronous requests + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(request_loop()) + loop.close() + + def response_handler(self): + while True: + try: + client_id, response = self.response_queue.get( + timeout=10 + ) # Block until response is available + if not response.success: + raise ValueError(f"Request failed with error: {response.error}") + self.client_records[client_id]["history"] += response.generated_text + self.client_records[client_id]["round"] += 1 + self.performance_metrics["ttft"].append(response.ttft) + self.performance_metrics["latency"].append(response.latency) + + if self.client_records[client_id]["round"] < args.num_rounds: + 
self.client_records[client_id][ + "history" + ] += self.candidate_inputs.pop() + self.ready_queue.append( + ( + client_id, + gen_payload( + self.client_records[client_id]["history"], + args.output_length, + ), + ) + ) + except queue.Empty: + if self.pbar.n == self.pbar.total: + break + + def run(self): + request_thread = threading.Thread(target=self.request_sender, daemon=True) + response_thread = threading.Thread(target=self.response_handler, daemon=True) + + self.start_time = time.time() + request_thread.start() + response_thread.start() + + request_thread.join() + response_thread.join() + + self.pbar.close() + print("All requests completed.") + print("Performance metrics summary:") + print( + f" Total requests: {len(self.performance_metrics['ttft'])} at {self.request_rate} requests per second" + ) + print( + f" Average TTFT: {sum(self.performance_metrics['ttft']) / len(self.performance_metrics['ttft']):.2f}" + ) + print( + f" Median TTFT: {sorted(self.performance_metrics['ttft'])[len(self.performance_metrics['ttft']) // 2]:.2f}" + ) + print( + f" Average latency: {sum(self.performance_metrics['latency']) / len(self.performance_metrics['latency']):.2f}" + ) + print( + f" Median latency: {sorted(self.performance_metrics['latency'])[len(self.performance_metrics['latency']) // 2]:.2f}" + ) + throughput = self.pbar.total / (self.finished_time - self.start_time) + print(f"Throughput: {throughput:.2f} requests per second") + + +if __name__ == "__main__": + args = parse_args() + flush_cache_url = f"http://{args.host}:{args.port}/flush_cache" + + for request_rate in range(1, 41, 2): + args.request_rate = request_rate + requests.post(flush_cache_url) + WorkloadGenerator(args).run() diff --git a/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py b/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py new file mode 100644 index 00000000000..f8c87d48db7 --- /dev/null +++ b/benchmark/kernels/decoding_attention_triton/triton_flashinfer_cudnn.py @@ -0,0 +1,405 @@ +import itertools +import math + +import cudnn +import torch +import torch.utils.benchmark as benchmark +import triton +import triton.language as tl +from flashinfer import BatchDecodeWithPagedKVCacheWrapper + +from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd +from sglang.srt.utils import should_use_tensor_core + + +def benchmark_forward( + fn, + *inputs, + repeats=10, + amp=False, + amp_dtype=torch.float16, + **kwinputs, +): + def amp_wrapper(*inputs, **kwinputs): + with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp): + fn(*inputs, **kwinputs) + + t = benchmark.Timer( + stmt="fn_amp(*inputs, **kwinputs)", + globals={"fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs}, + num_threads=torch.get_num_threads(), + ) + m = t.timeit(repeats) + return t, m + + +def time_fwd(func, *args, **kwargs): + time_f = benchmark_forward(func, *args, **kwargs) + return time_f[1].mean * 1e6 + + +def decode_attention_sglang( + q, + kv_data, + batch_size, + kv_len, + head_num_q, + head_num_kv, + head_dim, + num_kv_splits, + warmup=10, +): + + k_buffer = kv_data[0].view(-1, head_num_kv, head_dim) + v_buffer = kv_data[1].view(-1, head_num_kv, head_dim) + o = torch.empty_like(q) + total_tokens = batch_size * kv_len + req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len) + b_req_idx = torch.arange(0, batch_size).to(0).int() + b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32, device="cuda") + max_len_in_batch = kv_len + 
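+    # sm_scale is the standard 1/sqrt(head_dim) attention scaling factor; attn_logits
+    # (allocated below) is the intermediate buffer holding per-split partial results
+    # that decode_attention_fwd merges across num_kv_splits.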
sm_scale = 1.0 / (head_dim**0.5) + + attn_logits = torch.empty( + (batch_size, head_num_q, num_kv_splits, head_dim + 1), + dtype=torch.float32, + device="cuda", + ) + + for _ in range(warmup): + decode_attention_fwd( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_req_idx, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + ) + + f = time_fwd( + decode_attention_fwd, + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_req_idx, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + ) + + return f, o + + +def decode_attention_flashinfer(dtype, head_num_q, head_num_kv): + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda") + use_tensor_cores = should_use_tensor_core( + kv_cache_dtype=dtype, + num_attention_heads=head_num_q, + num_kv_heads=head_num_kv, + ) + flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores + ) + + class FlashinferAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + kv_data, + batch_size, + kv_len, + head_num_q, + head_num_kv, + head_dim, + dtype, + warmup=10, + ): + total_tokens = batch_size * kv_len + kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len + kv_indices = torch.arange(0, total_tokens).to(0).int() + kv_last_page_len = torch.full( + (batch_size,), 1, dtype=torch.int32, device="cuda" + ) + + flashinfer_decode_wrapper.end_forward() + flashinfer_decode_wrapper.begin_forward( + kv_indptr, + kv_indices, + kv_last_page_len, + head_num_q, + head_num_kv, + head_dim, + 1, + pos_encoding_mode="NONE", + data_type=dtype, + ) + + for _ in range(warmup): + o = flashinfer_decode_wrapper.forward( + q.contiguous().view(-1, head_num_q, head_dim), kv_data + ) + + f = time_fwd( + flashinfer_decode_wrapper.forward, + q.contiguous().view(-1, head_num_q, head_dim), + kv_data, + ) + + return f, o + + return FlashinferAttention + + +def convert_to_cudnn_type(torch_type): + if torch_type == torch.float16: + return cudnn.data_type.HALF + elif torch_type == torch.bfloat16: + return cudnn.data_type.BFLOAT16 + elif torch_type == torch.float32: + return cudnn.data_type.FLOAT + elif torch_type == torch.int32: + return cudnn.data_type.INT32 + elif torch_type == torch.int64: + return cudnn.data_type.INT64 + else: + raise ValueError("Unsupported tensor data type.") + + +def decode_attention_cudnn( + q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype, warmup=10 +): + # Prepare data: continuous q,k,v + dims_q = (batch_size, head_num_q, 1, head_dim) + strides_q = (head_num_q * head_dim, head_dim, head_num_q * head_dim, 1) + q_gpu = q.as_strided(dims_q, strides_q) + o_gpu = ( + torch.empty(batch_size * head_num_q * head_dim) + .half() + .cuda() + .as_strided(dims_q, strides_q) + ) + + dims_kv = (batch_size, head_num_kv, kv_len, head_dim) + strides_kv = ( + kv_len * head_num_kv * head_dim, + head_dim, + head_num_kv * head_dim, + 1, + ) + k_gpu = kv_data[0].as_strided(dims_kv, strides_kv) + v_gpu = kv_data[1].as_strided(dims_kv, strides_kv) + + seq_len_q_gpu = torch.full((batch_size, 1, 1, 1), 1, device="cuda") + seq_len_kv_gpu = torch.full((batch_size, 1, 1, 1), kv_len, device="cuda") + attn_scale = 1.0 / (head_dim**0.5) + + # Prepare data: paged k,v + block_size = 1 + blocks_per_batch = math.ceil(kv_len / block_size) + # [num_blocks, head_num_kv, block_size, head_dim], num_blocks = batch_size * blocks_per_batch + container_k_gpu = torch.cat(k_gpu.chunk(blocks_per_batch, dim=2), dim=0) + container_v_gpu = 
torch.cat(v_gpu.chunk(blocks_per_batch, dim=2), dim=0) + page_table_k_gpu = ( + torch.linspace( + 0, + batch_size * blocks_per_batch - 1, + batch_size * blocks_per_batch, + device="cuda", + dtype=torch.int32, + ) + .reshape(blocks_per_batch, 1, batch_size, 1) + .transpose(0, 2) + ) + page_table_v_gpu = page_table_k_gpu.clone() + + graph = cudnn.pygraph( + io_data_type=convert_to_cudnn_type(dtype), + intermediate_data_type=cudnn.data_type.FLOAT, + compute_data_type=cudnn.data_type.FLOAT, + ) + + q = graph.tensor_like(q_gpu) + container_k = graph.tensor_like(container_k_gpu) + container_v = graph.tensor_like(container_v_gpu) + page_table_k = graph.tensor_like(page_table_k_gpu) + page_table_v = graph.tensor_like(page_table_v_gpu) + + seq_len_q = graph.tensor_like(seq_len_q_gpu) + seq_len_kv = graph.tensor_like(seq_len_kv_gpu) + + o, _ = graph.sdpa( + name="sdpa", + q=q, + k=container_k, # Container K: non contiguous container with K blocks + v=container_v, # Container V: non contiguous container with V blocks + is_inference=True, + attn_scale=attn_scale, + use_causal_mask=False, + use_padding_mask=True, + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, + paged_attention_k_table=page_table_k, # Page Table K: Tensor containing offsets to the container with K blocks + paged_attention_v_table=page_table_v, # Page Table V: Tensor containing offsets to the container with V blocks + paged_attention_max_seq_len_kv=kv_len, # The maximum sequence length for K caches (this is optional, but recommended) + ) + + o.set_output(True).set_dim(dims_q).set_stride(strides_q) + + graph.validate() + graph.build_operation_graph() + graph.create_execution_plans([cudnn.heur_mode.A]) + graph.check_support() + graph.build_plans() + + workspace = torch.empty( + graph.get_workspace_size(), device="cuda", dtype=torch.uint8 + ) + + variant_pack = { + q: q_gpu, + container_k: container_k_gpu, + container_v: container_v_gpu, + page_table_k: page_table_k_gpu, + page_table_v: page_table_v_gpu, + seq_len_q: seq_len_q_gpu, + seq_len_kv: seq_len_kv_gpu, + o: o_gpu, + } + + for _ in range(warmup): + graph.execute(variant_pack, workspace) + + f = time_fwd( + graph.execute, + variant_pack, + workspace, + ) + + return f, o_gpu.squeeze(dim=2) + + +def calculate_diff(): + + dtype = torch.float16 + batch_size = 64 + kv_len = 4096 + head_num_q = 64 + head_num_kv = 8 + head_dim = 128 + + q = torch.randn(batch_size, head_num_q, head_dim, dtype=dtype, device="cuda") + kv_data = ( + torch.randn( + batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda" + ), + torch.randn( + batch_size * kv_len, head_num_kv, head_dim, dtype=dtype, device="cuda" + ), + ) + + _, output_sglang = decode_attention_sglang( + q, + kv_data, + batch_size, + kv_len, + head_num_q, + head_num_kv, + head_dim, + num_kv_splits=8, + ) + + attn_flashinfer = decode_attention_flashinfer(dtype, head_num_q, head_num_kv).apply + _, output_flashinfer = attn_flashinfer( + q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype + ) + + _, output_cudnn = decode_attention_cudnn( + q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype + ) + + print(f"SGLang output={output_sglang}") + print(f"FlashInfer output={output_flashinfer}") + print(f"cuDNN output={output_cudnn}") + if torch.allclose(output_sglang, output_flashinfer, atol=1e-2, rtol=1e-2): + print("✅ SGLang[Triton] and FlashInfer match") + else: + print("❌ SGLang[Triton] and FlashInfer differ") + + if torch.allclose(output_sglang, output_cudnn, atol=1e-2, rtol=1e-2): + print("✅ 
SGLang[Triton] and cuDNN match") + else: + print("❌ SGLang[Triton] and cuDNN differ") + + +if __name__ == "__main__": + calculate_diff() + + head_dim = 128 + dtype = torch.float16 + batch_size_range = [2**i for i in range(0, 8, 2)] + kv_len_range = [2**i for i in range(6, 13, 1)] + configs = list(itertools.product(batch_size_range, kv_len_range)) + + for head_num_q, head_num_kv in [[32, 32], [64, 8], [40, 8]]: + attn_flashinfer = decode_attention_flashinfer( + dtype, head_num_q, head_num_kv + ).apply + for batch_size, kv_len in configs: + q = torch.randn( + batch_size, head_num_q, head_dim, dtype=dtype, device="cuda" + ) + kv_data = ( + torch.randn( + batch_size * kv_len, + head_num_kv, + head_dim, + dtype=dtype, + device="cuda", + ), + torch.randn( + batch_size * kv_len, + head_num_kv, + head_dim, + dtype=dtype, + device="cuda", + ), + ) + us_cudnn, output_cudnn = decode_attention_cudnn( + q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype + ) + us_sglang, output_sglang = decode_attention_sglang( + q, + kv_data, + batch_size, + kv_len, + head_num_q, + head_num_kv, + head_dim, + num_kv_splits=8, + ) + us_flashinfer, _ = attn_flashinfer( + q, kv_data, batch_size, kv_len, head_num_q, head_num_kv, head_dim, dtype + ) + print( + head_num_q, + " ", + head_num_kv, + " ", + batch_size, + " ", + kv_len, + " ", + us_cudnn, + " ", + us_sglang, + " ", + us_flashinfer, + ) diff --git a/benchmark/kernels/fused_moe_triton/README.md b/benchmark/kernels/fused_moe_triton/README.md index ba29ede5099..2a3e37f6874 100644 --- a/benchmark/kernels/fused_moe_triton/README.md +++ b/benchmark/kernels/fused_moe_triton/README.md @@ -10,7 +10,7 @@ Example usage: ```bash # Tune Qwen2-57B with FP8 and TP=4 python benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py \ - --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \ + --model Qwen/Qwen2-57B-A14B-Instruct \ --tp-size 4 \ --dtype fp8_w8a8 \ --tune @@ -34,7 +34,7 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri # Compare with FP8 mode for Qwen2-57B python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py \ - --model Qwen/Qwen2-57B-A14B-Instruct-FP8 \ + --model Qwen/Qwen2-57B-A14B-Instruct \ --use-fp8 # Compare with custom TP size @@ -43,3 +43,7 @@ python benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_tri ``` The benchmark results will be saved as plots and data files in the specified output directory (default: `./configs/benchmark_ops/vllm_sglang_fused_moe/`). + +- `benchmark_torch_compile_fused_moe.py`: A tool for benchmarking the performance of the fused MoE kernel with `torch.compile` and original fused MoE kernel. + +Usage is the same as `benchmark_vllm_vs_sglang_fused_moe_triton.py`. 
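+
+For intuition, the comparison boils down to timing the same kernel with and without `torch.compile`; the sketch below uses a toy matmul as a stand-in for the fused MoE kernel and `triton.testing.do_bench` for timing (both are illustrative assumptions, not the script's exact code):
+
+```python
+import torch
+import triton
+
+def moe_like_kernel(x, w):
+    # Toy stand-in for the fused MoE kernel, for illustration only.
+    return torch.relu(x @ w)
+
+compiled = torch.compile(moe_like_kernel)
+
+x = torch.randn(1024, 4096, device="cuda", dtype=torch.float16)
+w = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
+
+for name, fn in [("eager", moe_like_kernel), ("torch.compile", compiled)]:
+    ms = triton.testing.do_bench(lambda: fn(x, w))  # median latency in milliseconds
+    print(f"{name}: {ms:.3f} ms")
+```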
diff --git a/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py new file mode 100644 index 00000000000..e2c4d8d3506 --- /dev/null +++ b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py @@ -0,0 +1,313 @@ +import argparse +import itertools + +import torch +import triton +import triton.language as tl +from sgl_kernel import moe_align_block_size + +USE_RANDOM_PERM = False + + +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = pid * tokens_per_thread + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + 
cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + +def calculate_diff(batch_size, seq_len): + num_experts = 256 + block_size = 128 + topk = 8 + + topk_ids = torch.stack( + [ + torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] + for _ in range(batch_size * seq_len) + ] + ) + + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids_cuda = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) + sorted_ids_cuda.fill_(topk_ids.numel()) + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids_cuda = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad_cuda = torch.empty( + (1), dtype=torch.int32, device=topk_ids.device + ) + token_cnts_buffer = torch.empty( + (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device + ) + cumsum_buffer = torch.empty( + num_experts + 1, dtype=torch.int32, device=topk_ids.device + ) + + sorted_ids_triton = torch.empty_like(sorted_ids_cuda) + sorted_ids_triton.fill_(topk_ids.numel()) + expert_ids_triton = torch.empty_like(expert_ids_cuda) + num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda) + + # compare the performance of cuda and triton implementation + moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids_cuda, + expert_ids_cuda, + num_tokens_post_pad_cuda, + token_cnts_buffer, + cumsum_buffer, + ) + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids_triton, + expert_ids_triton, + num_tokens_post_pad_triton, + ) + + if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose( + num_tokens_post_pad_cuda, num_tokens_post_pad_triton + ): + print("✅ CUDA and Triton implementations match") + else: + print("❌ CUDA and Triton implementations do not match") + print("CUDA expert_ids:", expert_ids_cuda) + print("Triton expert_ids:", expert_ids_triton) + print("CUDA num_tokens_post_pad:", num_tokens_post_pad_cuda) + print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton) + + +batch_size_range = [2**i for i in range(0, 8)] +seq_length_range = [2**i for i in range(0, 16)] +configs = list(itertools.product(batch_size_range, seq_length_range)) + + +def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: + topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda") + for i in range(num_tokens): + topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="cuda")[ + :topk + ] + return topk_ids + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["cuda", "triton"], + line_names=["CUDA", "Triton"], + styles=[("blue", "-"), ("red", "-")], + ylabel="us", + plot_name="moe-align-block-size-performance", + args={}, + ) +) +def benchmark(batch_size, seq_len, provider): + num_experts = 256 + block_size = 128 + topk = 8 + + if USE_RANDOM_PERM: + topk_ids = get_topk_ids(batch_size * seq_len, num_experts, topk) + else: + topk_ids = torch.randint( + 0, + num_experts, + (batch_size * seq_len, topk), + dtype=torch.int32, + device="cuda", + ) + + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, 
device=topk_ids.device + ) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) + token_cnts_buffer = torch.empty( + (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device + ) + cumsum_buffer = torch.empty( + num_experts + 1, dtype=torch.int32, device=topk_ids.device + ) + + quantiles = [0.5, 0.2, 0.8] + if provider == "cuda": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids.clone(), + expert_ids.clone(), + num_tokens_post_pad.clone(), + token_cnts_buffer, + cumsum_buffer, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids.clone(), + expert_ids.clone(), + num_tokens_post_pad.clone(), + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./configs/benchmark_ops/moe_align_blocks/", + help="Path to save moe align benchmark results", + ) + args = parser.parse_args() + + calculate_diff(batch_size=4, seq_len=1024) + + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py index 1bd6eec1645..b81a2280024 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_torch_compile_fused_moe.py @@ -5,13 +5,15 @@ from torch.nn import functional as F from transformers import AutoConfig -from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + fused_moe as fused_moe_triton, +) from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config def get_model_config(model_name: str, tp_size: int): """Get model configuration parameters""" - config = AutoConfig.from_pretrained(model_name) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) if config.architectures[0] == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts @@ -28,6 +30,11 @@ def get_model_config(model_name: str, tp_size: int): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // tp_size + elif config.architectures[0] == "DeepseekV2ForCausalLM": + E = config.n_routed_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size else: # Default: Mixtral E = config.num_local_experts diff --git a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py index 7bfb2731b98..faf5c6b4e78 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_vllm_vs_sglang_fused_moe_triton.py @@ -5,12 +5,14 @@ from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm -from 
sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + fused_moe as fused_moe_sglang, +) def get_model_config(model_name: str, tp_size: int): """Get model configuration parameters""" - config = AutoConfig.from_pretrained(model_name) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) if config.architectures[0] == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts @@ -27,6 +29,11 @@ def get_model_config(model_name: str, tp_size: int): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // tp_size + elif config.architectures[0] == "DeepseekV2ForCausalLM": + E = config.n_routed_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size else: # Default: Mixtral E = config.num_local_experts diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py index 6f6a57be1de..249401d0910 100644 --- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py @@ -11,13 +11,16 @@ from ray.experimental.tqdm_ray import tqdm from transformers import AutoConfig -from sglang.srt.layers.fused_moe_triton.fused_moe import ( +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( fused_moe, get_config_dtype_str, get_config_file_name, get_default_config, get_moe_configs, ) +from sglang.srt.utils import is_hip + +_is_hip_ = is_hip() class BenchmarkConfig(TypedDict): @@ -39,6 +42,7 @@ def benchmark_config( dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + block_shape: List[int] = None, num_iters: int = 100, ) -> float: init_dtype = torch.float16 if use_fp8_w8a8 else dtype @@ -83,13 +87,26 @@ def benchmark_config( ) w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) if use_fp8_w8a8: - w1_scale = torch.randn(num_experts, dtype=torch.float32) - w2_scale = torch.randn(num_experts, dtype=torch.float32) - a1_scale = torch.randn(1, dtype=torch.float32) - a2_scale = torch.randn(1, dtype=torch.float32) + if block_shape is None: + w1_scale = torch.randn(num_experts, dtype=torch.float32) + w2_scale = torch.randn(num_experts, dtype=torch.float32) + a1_scale = torch.randn(1, dtype=torch.float32) + a2_scale = torch.randn(1, dtype=torch.float32) + else: + block_n, block_k = block_shape[0], block_shape[1] + n_tiles_w1 = (shard_intermediate_size + block_n - 1) // block_n + n_tiles_w2 = (hidden_size + block_n - 1) // block_n + k_tiles_w1 = (hidden_size + block_k - 1) // block_k + k_tiles_w2 = (shard_intermediate_size // 2 + block_k - 1) // block_k + w1_scale = torch.rand( + (num_experts, n_tiles_w1, k_tiles_w1), dtype=torch.float32 + ) + w2_scale = torch.rand( + (num_experts, n_tiles_w2, k_tiles_w2), dtype=torch.float32 + ) - w1 = w1.to(torch.float8_e4m3fn) - w2 = w2.to(torch.float8_e4m3fn) + w1 = w1.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn) input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) @@ -97,7 +114,7 @@ def prepare(i: int): input_gating.copy_(gating_output[i]) def run(): - from sglang.srt.layers.fused_moe_triton import override_config + from sglang.srt.layers.moe.fused_moe_triton import override_config with override_config(config): 
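+            # override_config temporarily swaps in the candidate Triton kernel config
+            # so fused_moe is timed with exactly this configuration.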
fused_moe( @@ -114,6 +131,7 @@ def run(): w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) # JIT compilation & warmup @@ -150,17 +168,15 @@ def run(): return avg -def get_configs_compute_bound() -> List[Dict[str, int]]: - # Reduced search space for faster tuning. - # TODO(woosuk): Increase the search space and use a performance model to - # prune the search space. +def get_rocm_configs_compute_bound() -> List[Dict[str, int]]: configs: List[BenchmarkConfig] = [] - for num_stages in [2, 3, 4, 5]: - for block_m in [16, 32, 64, 128, 256]: - for block_k in [64, 128, 256]: - for block_n in [32, 64, 128, 256]: + waves_per_eu_range = 0 + for num_stages in [2]: + for block_m in [32, 64, 128, 256]: + for block_k in [32, 64, 128, 256]: + for block_n in [16, 32, 64, 128, 256]: for num_warps in [4, 8]: - for group_size in [1, 16, 32, 64]: + for group_size in [1, 4, 8, 16, 32]: configs.append( { "BLOCK_SIZE_M": block_m, @@ -169,11 +185,39 @@ def get_configs_compute_bound() -> List[Dict[str, int]]: "GROUP_SIZE_M": group_size, "num_warps": num_warps, "num_stages": num_stages, + "waves_per_eu": waves_per_eu_range, } ) return configs +def get_configs_compute_bound() -> List[Dict[str, int]]: + # Reduced search space for faster tuning. + # TODO(woosuk): Increase the search space and use a performance model to + # prune the search space. + configs: List[BenchmarkConfig] = [] + if _is_hip_: + configs = get_rocm_configs_compute_bound() + else: + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128, 256]: + for block_k in [64, 128, 256]: + for block_n in [32, 64, 128, 256]: + for num_warps in [4, 8]: + for group_size in [1, 16, 32, 64]: + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + @ray.remote(num_gpus=1) class BenchmarkWorker: @@ -192,6 +236,7 @@ def benchmark( dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + block_shape: List[int], ) -> Tuple[Dict[str, int], float]: torch.cuda.manual_seed_all(0) dtype_str = get_config_dtype_str( @@ -199,8 +244,10 @@ def benchmark( ) # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. 
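+        # block_shape carries [block_n, block_k] from quantization_config["weight_block_size"]
+        # for block-quantized FP8 checkpoints; fall back to 0 when the model is not
+        # block-quantized.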
+ block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 op_config = get_moe_configs( - num_experts, shard_intermediate_size // 2, dtype_str + num_experts, shard_intermediate_size // 2, dtype_str, block_n, block_k ) if op_config is None: config = get_default_config( @@ -210,6 +257,7 @@ def benchmark( hidden_size, topk, dtype_str, + False, ) else: config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] @@ -223,6 +271,7 @@ def benchmark( dtype, use_fp8_w8a8, use_int8_w8a16, + block_shape, ) return config, kernel_time @@ -236,6 +285,7 @@ def tune( dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + block_shape: List[int], search_space: List[Dict[str, int]], ) -> Dict[str, int]: best_config = None @@ -252,6 +302,7 @@ def tune( dtype, use_fp8_w8a8, use_int8_w8a16, + block_shape, num_iters=10, ) except triton.runtime.autotuner.OutOfResources: @@ -275,6 +326,9 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: "GROUP_SIZE_M": config["GROUP_SIZE_M"], "num_warps": config["num_warps"], "num_stages": config["num_stages"], + **( + {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {} + ), } @@ -287,6 +341,7 @@ def save_configs( dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + block_shape: List[int], ) -> None: dtype_str = get_config_dtype_str( dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 @@ -295,7 +350,10 @@ def save_configs( # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. filename = get_config_file_name( - num_experts, shard_intermediate_size // 2, dtype_str + num_experts, + shard_intermediate_size // 2, + dtype_str, + block_shape, ) print(f"Writing best config to {filename}...") @@ -307,7 +365,7 @@ def save_configs( def main(args: argparse.Namespace): print(args) - config = AutoConfig.from_pretrained(args.model) + config = AutoConfig.from_pretrained(args.model, trust_remote_code=True) if config.architectures[0] == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k @@ -323,6 +381,11 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: + E = config.n_routed_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size else: # Default: Mixtral E = config.num_local_experts @@ -334,6 +397,13 @@ def main(args: argparse.Namespace): dtype = config.torch_dtype use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" + block_shape = None + if ( + hasattr(config, "quantization_config") + and "weight_block_size" in config.quantization_config + ): + block_shape = config.quantization_config["weight_block_size"] + assert len(block_shape) == 2 if args.batch_size is None: batch_sizes = [ @@ -376,6 +446,13 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: if args.tune: search_space = get_configs_compute_bound() + if block_shape is not None: + block_n, block_k = block_shape[0], block_shape[1] + search_space = [ + config + for config in search_space + if block_k % config["BLOCK_SIZE_K"] == 0 + ] print(f"Start tuning over {len(search_space)} configurations...") start = time.time() @@ -391,6 +468,7 @@ def 
_distribute(method: str, inputs: List[Any]) -> List[Any]: dtype, use_fp8_w8a8, use_int8_w8a16, + block_shape, search_space, ) for batch_size in batch_sizes @@ -408,6 +486,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: dtype, use_fp8_w8a8, use_int8_w8a16, + block_shape, ) end = time.time() print(f"Tuning took {end - start:.2f} seconds") @@ -424,6 +503,7 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]: dtype, use_fp8_w8a8, use_int8_w8a16, + block_shape, ) for batch_size in batch_sizes ], diff --git a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py new file mode 100644 index 00000000000..57fbcfddf2c --- /dev/null +++ b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py @@ -0,0 +1,577 @@ +import itertools +import math +import os +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange +from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode + + +@triton.jit +def _decode_kernel( + Q, + K, + V, + KV, + Out, + S, + b: tl.constexpr, + h: tl.constexpr, + n: tl.constexpr, + d: tl.constexpr, + d_original: tl.constexpr, + e: tl.constexpr, + e_original: tl.constexpr, +): + off_bh = tl.program_id(0) + off_h = off_bh % h + + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + kv_offset = off_bh * d * e + + s = tl.load(S + off_h) + ratio = tl.exp(-s) + + d_idx = tl.arange(0, d) + e_idx = tl.arange(0, e) + + # Create masks for original dimensions + d_mask = d_idx < d_original + e_mask = e_idx < e_original + + # Load with masking + q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0) + k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0) + v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0) + + # Load KV with 2D masking + kv = tl.load( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + mask=(d_mask[:, None] & e_mask[None, :]), + other=0.0, + ) + + # Compute outer product using element-wise operations + k_v_prod = k[:, None] * v[None, :] + kv = ratio * kv + k_v_prod + + # Store KV with 2D masking + tl.store( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + kv.to(KV.dtype.element_ty), + mask=(d_mask[:, None] & e_mask[None, :]), + ) + + # Compute matrix-vector multiplication using element-wise operations and reduction + o = tl.sum(q[:, None] * kv, axis=0) + + # Store output with masking + tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask) + + +def lightning_attn_decode(q, k, v, kv, s): + """Triton implementation of Lightning Attention decode operation""" + b, h, n, d = q.shape + e = v.shape[-1] + assert n == 1, "Sequence length must be 1 in decode mode" + + # Get padded dimensions (power of 2) + d_padded = next_power_of_2(d) + e_padded = next_power_of_2(e) + + # Create output tensor (padded) + o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + + # Create padded tensors without actually padding the data + q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device) + k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device) + v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + kv_padded = torch.empty( + b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device + ) + + # Copy 
data to padded tensors + q_padded[..., :d] = q + k_padded[..., :d] = k + v_padded[..., :e] = v + kv_padded[..., :d, :e] = kv + + # Launch kernel + grid = (b * h, 1) + _decode_kernel[grid]( + q_padded, + k_padded, + v_padded, + kv_padded, + o_padded, + s, + b=b, + h=h, + n=n, + d=d_padded, + d_original=d, + e=e_padded, + e_original=e, + ) + + # Get unpadded outputs + o = o_padded[..., :e] + kv_out = kv_padded[..., :d, :e] + + return o, kv_out + + +def next_power_of_2(n): + return 2 ** (int(math.ceil(math.log(n, 2)))) + + +class MiniMaxText01LightningAttention(nn.Module): + def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs): + super().__init__() + if config is None: + config = type("Config", (), kwargs) + + bias = False + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + + self.out_proj = nn.Linear( + self.head_dim * self.num_heads, self.hidden_size, bias=bias + ) + self.act = get_activation_fn(config.hidden_act) + self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads) + + self.qkv_proj = nn.Linear( + self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias + ) + self.output_gate = nn.Linear( + self.hidden_size, self.head_dim * self.num_heads, bias=bias + ) + + # for inference only + self.offset = 0 + self.layer_idx = layer_idx + + def forward( + self, + hidden_states, + attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m) + output_attentions: bool = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + slope_rate: Optional[torch.Tensor] = None, + **kwargs, + ): + if (not self.training) and (not do_eval): + return self.inference( + hidden_states, + attn_mask, + output_attentions, + past_key_value, + use_cache, + slope_rate, + ) + + def inference( + self, + x, + attn_mask: Optional[torch.Tensor] = None, # (b, n) + output_attentions: bool = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1) + ): + # x: b n d + b, n, d = x.shape + # linear map + qkv = self.act(self.qkv_proj(x)) + new_shape = qkv.size()[:-1] + (self.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3) + q = q.transpose(1, 2) # [b, n, h, d] -> [b, h, n, d] + k = k.transpose(1, 2) # [b, n, h, d] -> [b, h, n, d] + v = v.transpose(1, 2) # [b, n, h, d] -> [b, h, n, e] + + self.offset += 1 + ratio = torch.exp(-slope_rate) # [h, 1, 1] + + # decode mode + kv = past_key_value # [b, h, d, e] + output = [] + for i in range(n): + # kv: [b, h, d, e] + # ratio: [h, 1, 1] + # k: [b, h, n, d] + # v: [b, h, n, e] + # k[:, :, i : i + 1]: [b, h, 1, d] + # v[:, :, i : i + 1]: [b, h, 1, e] + # ratio * kv: [b, h, d, e] + # torch.einsum( + # "... n d, ... n e -> ... d e", + # k[:, :, i : i + 1], + # v[:, :, i : i + 1], + # ) + # [b, h, d, e] + [b, h, d, e] -> [b, h, d, e] + kv = ratio * kv + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + # q[:, :, i : i + 1]: [b, h, 1, d] + # kv.to(q.dtype): [b, h, d, e] + # torch.einsum( + # "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype) + # ) + # [b, h, 1, d] * [b, h, d, e] -> [b, h, 1, e] + qkv = torch.einsum( + "... n e, ... e d -> ... 
n d", q[:, :, i : i + 1], kv.to(q.dtype) + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + + # reshape + output = rearrange(output, "b h n d -> b n (h d)") + # normalize + output = self.norm(output) + # gate + output = F.sigmoid(self.output_gate(x)) * output + # outproj + output = self.out_proj(output) + + attn_weights = None + + return output, attn_weights, kv + + +def get_activation_fn(activation): + if activation == "gelu": + return F.gelu + elif activation == "relu": + return F.relu + elif activation == "elu": + return F.elu + elif activation == "sigmoid": + return F.sigmoid + elif activation == "exp": + + def f(x): + with torch.no_grad(): + x_max = torch.max(x, dim=-1, keepdims=True).values + y = torch.exp(x - x_max) + return y + + return f + elif activation == "leak": + return F.leaky_relu + elif activation == "1+elu": + + def f(x): + return 1 + F.elu(x) + + return f + elif activation == "2+elu": + + def f(x): + return 2 + F.elu(x) + + return f + elif activation == "silu" or activation == "swish": + return F.silu + elif activation == "sine": + return torch.sin + else: + return lambda x: x + + +class MiniMaxText01RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MiniMaxText01RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +def test_lightning_attention_implementations(model_params): + torch.manual_seed(42) + + batch_size = 64 + seq_len = 1 + dtype = torch.bfloat16 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + hidden_states = torch.randn( + batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device + ) + + attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device) + + slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device) + + model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device) + model_attn.eval() + + d = model_params["head_dim"] + past_kv = torch.randn( + batch_size, + model_params["num_attention_heads"], + d, + d, + device=device, + ) + with torch.no_grad(): + model_output, _, new_kv = model_attn.inference( + hidden_states, + attn_mask=attention_mask, + slope_rate=slope_rate, + past_key_value=past_kv, + ) + + qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + past_kv = past_kv.contiguous() + slope_rate = slope_rate.contiguous() + + # Test Triton implementation + triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate) + triton_output = triton_output.transpose(1, 2).contiguous() + triton_output = triton_output.view(batch_size, seq_len, -1) + triton_output = model_attn.norm(triton_output) + triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * triton_output + triton_output = model_attn.out_proj(triton_output) + + # Test SGL implementation + sgl_output = torch.empty_like(v) + sgl_new_kv = 
torch.empty_like(past_kv) + sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv) + + sgl_output = sgl_output.transpose(1, 2).contiguous() + sgl_output = sgl_output.view(batch_size, seq_len, -1) + sgl_output = model_attn.norm(sgl_output) + sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output + sgl_output = model_attn.out_proj(sgl_output) + + # Verify Triton implementation results + torch.testing.assert_close( + model_output, + triton_output, + rtol=1e-3, + atol=1e-2, + msg="Triton lightning attention implementation produces different output results", + ) + torch.testing.assert_close( + new_kv, + triton_new_kv, + rtol=1e-3, + atol=1e-2, + msg="Triton lightning attention implementation produces different kv results", + ) + + # Verify SGL implementation results + torch.testing.assert_close( + model_output, + sgl_output, + rtol=1e-3, + atol=1e-2, + msg="SGL lightning attention implementation produces different output results", + ) + torch.testing.assert_close( + new_kv, + sgl_new_kv, + rtol=1e-3, + atol=1e-2, + msg="SGL lightning attention implementation produces different kv results", + ) + + print("✅ All implementations match") + + +def _build_slope_tensor(n_attention_heads: int): + def get_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + slopes = torch.tensor(get_slopes(n_attention_heads)).reshape( + n_attention_heads, 1, 1 + ) + return slopes + + +def get_benchmark(): + batch_size_range = [i for i in range(1, 33)] # max 32 + seq_length_range = [1] # decode mode sequence length is fixed to 1 + configs = list(itertools.product(batch_size_range, seq_length_range)) + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["Original", "Triton", "SGL"], + line_names=[ + "Original PyTorch Implementation", + "Triton Implementation", + "SGL Implementation", + ], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name="lightning-attention-decode-performance", + args={}, + ) + ) + def benchmark(batch_size, seq_len, provider): + dtype = torch.bfloat16 + device = torch.device("cuda") + + params = { + "hidden_size": 6144, + "num_attention_heads": 64, + "head_dim": 96, + "hidden_act": "gelu", + } + + hidden_states = torch.randn( + batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device + ) + + attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device) + + slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device) + model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device) + model_attn.eval() + + d = params["head_dim"] + past_kv = torch.randn( + batch_size, + params["num_attention_heads"], + d, + d, + device=device, + ) + + quantiles = [0.5, 0.2, 0.8] + if provider == "Original": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: model_attn.inference( + hidden_states, + attn_mask=attention_mask, + slope_rate=slope_rate, + past_key_value=past_kv, + ), + quantiles=quantiles, + ) + elif provider == "Triton": + + def run_triton(): + qkv = 
model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + output, new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate) + output = output.transpose(1, 2).contiguous() + output = output.view(batch_size, seq_len, -1) + output = model_attn.norm(output) + output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output + return model_attn.out_proj(output) + + ms, min_ms, max_ms = triton.testing.do_bench( + run_triton, + quantiles=quantiles, + ) + else: # SGL + + def run_sgl(): + qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1) + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + + output = torch.empty_like(v) + new_kv = torch.empty_like(past_kv) + sgl_lightning_attention_decode( + q, k, v, past_kv, slope_rate, output, new_kv + ) + + output = output.transpose(1, 2).contiguous() + output = output.view(batch_size, seq_len, -1) + output = model_attn.norm(output) + output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output + return model_attn.out_proj(output) + + ms, min_ms, max_ms = triton.testing.do_bench( + run_sgl, + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./configs/benchmark_ops/lightning_attention_decode/", + help="Path to save lightning attention decode benchmark results", + ) + args = parser.parse_args() + + params = { + "hidden_size": 6144, + "num_attention_heads": 64, + "head_dim": 96, + "hidden_act": "silu", + } + # Run correctness test first + # Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json + test_lightning_attention_implementations(params) + + # Run performance benchmark + benchmark = get_benchmark() + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py new file mode 100644 index 00000000000..cd298487b59 --- /dev/null +++ b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py @@ -0,0 +1,603 @@ +import itertools +import math +import os +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange + + +# Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py +@triton.jit +def _fwd_kernel( + Q, + K, + V, + Out, + S, # log lambda + b: tl.constexpr, + h: tl.constexpr, + n: tl.constexpr, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK: tl.constexpr, + BLOCK_MODEL: tl.constexpr, +): + ##### get offset + off_bh = tl.program_id(0) + off_h = off_bh % h + off_e = tl.program_id(1) + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + # channel offset + e_offset = off_e * BLOCK_MODEL + + ##### get block ptr + 
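The Triton kernel here walks the sequence in chunks of size `BLOCK`, combining a causal intra-block term with an inter-block running state `kv`. As a reading aid, one block step for a single head can be sketched in plain PyTorch (this mirrors the non-Triton prefill branch of `MiniMaxText01LightningAttention.inference` later in this file; the helper name and the scalar-slope assumption are ours):

```python
import math

import torch


def block_step_reference(qi, ki, vi, kv, s):
    """One BLOCK step for one head: qi/ki (m, d), vi (m, e), kv (d, e), s a scalar slope."""
    m = qi.shape[0]
    idx = torch.arange(m, dtype=qi.dtype)
    q_decay = torch.exp(-s * (idx + 1)).unsqueeze(-1)      # decay of q against the carried state
    k_decay = torch.exp(-s * (m - 1 - idx)).unsqueeze(-1)  # decay of k toward the block end
    diff = idx[:, None] - idx[None, :]
    diag_decay = torch.where(diff >= 0, torch.exp(-s * diff), torch.zeros_like(diff))
    o_intra = (qi @ ki.T * diag_decay) @ vi                # causal attention within the block
    o_inter = (qi * q_decay) @ kv                          # contribution of all previous blocks
    kv = math.exp(-s * m) * kv + (ki * k_decay).T @ vi     # state handed to the next block
    return o_intra + o_inter, kv
```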
Q_block_ptr = Q + qk_offset + tl.arange(0, d)[None, :] + K_trans_block_ptr = K + qk_offset + tl.arange(0, d)[:, None] + V_block_ptr = V + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :] + O_block_ptr = Out + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :] + S_block_ptr = S + off_h + + ##### init diag decay(Lambda); q, k decay; kv + s = tl.load(S_block_ptr) + # q, k decay + off_block = tl.arange( + 0, BLOCK + ) # Not bug, this is a bit different from algorithm 1, but is mathematically equivalent + q_decay = tl.exp(-s.to(tl.float32) * off_block[:, None]) + k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - off_block[None, :])) + block_decay = tl.exp(-s.to(tl.float32) * BLOCK) + # diag decay + index = off_block[:, None] - off_block[None, :] + s_index = s * index + s_index = tl.where(index >= 0, -s_index, float("-inf")) + diag_decay = tl.exp(s_index) + kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32) + + ##### compute + for i in range(NUM_BLOCK): + # load + q = tl.load( + Q_block_ptr + off_block[:, None] * d, mask=off_block[:, None] < n, other=0.0 + ).to(tl.float32) + k_trans = tl.load( + K_trans_block_ptr + off_block[None, :] * d, + mask=off_block[None, :] < n, + other=0.0, + ).to(tl.float32) + v = tl.load( + V_block_ptr + off_block[:, None] * e, mask=off_block[:, None] < n, other=0.0 + ).to(tl.float32) + + # compute + qk = tl.dot(q, k_trans) * diag_decay + o_intra = tl.dot(qk, v) + o_inter = tl.dot(q, kv) * q_decay + o = o_intra + o_inter + + # save and update + tl.store( + O_block_ptr + off_block[:, None] * e, + o.to(O_block_ptr.dtype.element_ty), + mask=off_block[:, None] < n, + ) + kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v) + off_block += BLOCK + + +def lightning_attn2(q, k, v, s): + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + s = s.contiguous() + + b, h, n, d = q.shape + e = v.shape[-1] + + # Pad d to next power of 2 + d_padded = next_power_of_2(d) + if d_padded != d: + q_padded = F.pad(q, (0, d_padded - d)) + k_padded = F.pad(k, (0, d_padded - d)) + else: + q_padded = q + k_padded = k + + # Pad e to next power of 2 + e_padded = next_power_of_2(e) + if e_padded != e: + v_padded = F.pad(v, (0, e_padded - e)) + else: + v_padded = v + + o_padded = torch.empty((b, h, n, e_padded), dtype=q.dtype, device=q.device) + + BLOCK = 64 + NUM_BLOCK = triton.cdiv(q.shape[2], BLOCK) + # parallel over channel + BLOCK_MODEL = min(triton.next_power_of_2(e_padded), 32) + grid = (b * h, triton.cdiv(e_padded, BLOCK_MODEL)) + + _fwd_kernel[grid]( + q_padded, + k_padded, + v_padded, + o_padded, + s, + b, + h, + n, + d_padded, + e_padded, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + BLOCK_MODEL=BLOCK_MODEL, + ) + + # Remove padding from output + if e_padded != e: + o = o_padded[..., :e] + else: + o = o_padded + + return o + + +def is_support(dim): + return 16 % dim + + +def next_power_of_2(n): + return 2 ** (int(math.ceil(math.log(n, 2)))) + + +def lightning_attn_func(q, k, v, s): + b, h, n, d = q.shape + e = v.shape[-1] + assert is_support(d) and is_support(e) + + # pad v's feature dim to power of 2 + e_pad = next_power_of_2(e) + need_pad = e_pad != e + if need_pad: + v = F.pad(v, (0, e_pad - e)) + + if d > 128: + # split over head + if 64 % d: + m = 64 + elif 32 % d: + m = 32 + elif 16 % d: + m = 16 + arr = [m * i for i in range(d // m + 1)] + if arr[-1] != d: + arr.append(d) + n = len(arr) + o = 0 + for i in range(n - 1): + start = arr[i] + end = arr[i + 1] + q1 = q[..., start:end] + k1 = k[..., start:end] + o += lightning_attn2(q1, k1, v, s) + else: + o = 
lightning_attn2(q, k, v, s) + + if need_pad: + o = o[:, :, :, :e] + + return o + + +debug = eval(os.environ.get("debug", default="False")) + +BLOCK = 256 + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxText01 +class MiniMaxText01RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MiniMaxText01RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py +def get_activation_fn(activation): + if debug: + logger.info(f"activation: {activation}") + if activation == "gelu": + return F.gelu + elif activation == "relu": + return F.relu + elif activation == "elu": + return F.elu + elif activation == "sigmoid": + return F.sigmoid + elif activation == "exp": + + def f(x): + with torch.no_grad(): + x_max = torch.max(x, dim=-1, keepdims=True).values + y = torch.exp(x - x_max) + + return y + + return f + elif activation == "leak": + return F.leaky_relu + elif activation == "1+elu": + + def f(x): + return 1 + F.elu(x) + + return f + elif activation == "2+elu": + + def f(x): + return 2 + F.elu(x) + + return f + elif activation == "silu" or activation == "swish": + return F.silu + elif activation == "sine": + return torch.sin + else: + logger.info(f"activation: does not support {activation}, use Identity!!!") + return lambda x: x + + +# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py +class MiniMaxText01LightningAttention(nn.Module): + def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs): + super().__init__() + if config is None: + config = type("Config", (), kwargs) + + bias = False + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + + self.out_proj = nn.Linear( + self.head_dim * self.num_heads, self.hidden_size, bias=bias + ) + self.act = get_activation_fn(config.hidden_act) + self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads) + + self.qkv_proj = nn.Linear( + self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias + ) + self.output_gate = nn.Linear( + self.hidden_size, self.head_dim * self.num_heads, bias=bias + ) + + # for inference only + self.offset = 0 + self.layer_idx = layer_idx + + def forward( + self, + hidden_states, + attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m) + output_attentions: bool = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + slope_rate: Optional[torch.Tensor] = None, + **kwargs, + ): + if (not self.training) and (not do_eval): + return self.inference( + hidden_states, + attn_mask, + output_attentions, + past_key_value, + use_cache, + slope_rate, + ) + + def inference( + self, + x, + attn_mask: Optional[torch.Tensor] = None, # (b, n) + output_attentions: bool = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1) + ): + # x: b n d + b, n, d = x.shape + # linear map + 
qkv = self.act(self.qkv_proj(x)) + new_shape = qkv.size()[:-1] + (self.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if past_key_value is None: + self.offset = q.shape[-2] + else: + self.offset += 1 + + # for align with metaseq + ratio = torch.exp(-slope_rate) + + # only use for the first time + if past_key_value is None: + slope_rate = slope_rate.to(torch.float32) + if attn_mask is not None: + v = v.masked_fill( + (1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0 + ) + NUM_BLOCK = (n + BLOCK - 1) // BLOCK + b, h, n, d = q.shape + e = v.shape[-1] + # other + array = torch.arange(BLOCK).to(q) + 1 + q_decay = torch.exp(-slope_rate * array.reshape(-1, 1)) + k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1))) + index = array[:, None] - array[None, :] + s_index = ( + slope_rate + * index[ + None, + None, + ] + ) + s_index = torch.where(index >= 0, -s_index, float("-inf")) + diag_decay = torch.exp(s_index) + + kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device) + output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device) + for i in range(NUM_BLOCK): + si = i * BLOCK + ei = min(si + BLOCK, n) + m = ei - si + qi = q[:, :, si:ei].contiguous() + ki = k[:, :, si:ei].contiguous() + vi = v[:, :, si:ei].contiguous() + qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32) + + # diag + qk = ( + torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32) + * diag_decay[:, :, :m, :m] + ) + qkv_diag = torch.matmul(qk, vi.to(torch.float32)) + block_decay = torch.exp(-slope_rate * m) + output[:, :, si:ei] = qkv_none_diag + qkv_diag + kv = block_decay * kv + torch.matmul( + (ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi + ) + + else: + kv = past_key_value + output = [] + for i in range(n): + kv = ratio * kv + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + qkv = torch.einsum( + "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype) + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + # reshape + output = rearrange(output, "b h n d -> b n (h d)") + # normalize + output = self.norm(output) + # gate + output = F.sigmoid(self.output_gate(x)) * output + # outproj + output = self.out_proj(output) + + attn_weights = None + + return output, attn_weights, kv + + +def _build_slope_tensor(n_attention_heads: int): + def get_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2( + n + ) # In the paper, we only train models that have 2^a heads for some a. This function has + else: # some good properties that only occur when the input is a power of 2. To maintain that even + closest_power_of_2 = 2 ** math.floor( + math.log2(n) + ) # when the number of heads is not a power of 2, we use this workaround. 
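As a concrete sanity check, the slopes built here follow the ALiBi geometric schedule. A small self-contained example (the numbers below are implied by the formula above, not taken from the file):

```python
import math

n_heads = 8
start = 2 ** (-(2 ** -(math.log2(n_heads) - 3)))     # 2**-1 = 0.5 for 8 heads
slopes = [start * start**i for i in range(n_heads)]  # [0.5, 0.25, ..., 2**-8 = 0.00390625]
assert abs(slopes[-1] - 2**-8) < 1e-12
# For the 64-head config used in these benchmarks, the slopes are 2**(-(i + 1) / 8),
# ranging from about 0.917 down to 2**-8.
```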
+ return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + # h, 1, 1 + slopes = torch.tensor(get_slopes(n_attention_heads)).reshape( + n_attention_heads, 1, 1 + ) + + return slopes + + +def test_lightning_attention_implementations(model_params): + torch.manual_seed(42) + + batch_size = 2 + seq_len = 1024 + dtype = torch.bfloat16 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + hidden_states = torch.randn( + batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device + ) + + attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device) + + slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device) + + model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device) + model_attn.eval() + + with torch.no_grad(): + model_output, _, _ = model_attn.inference( + hidden_states, attn_mask=attention_mask, slope_rate=slope_rate + ) + + qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + lib_output = lightning_attn_func(q, k, v, slope_rate) + lib_output = lib_output.transpose(1, 2).contiguous() + lib_output = lib_output.view(batch_size, seq_len, -1) + lib_output = model_attn.norm(lib_output) + lib_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output + lib_output = model_attn.out_proj(lib_output) + + torch.testing.assert_close( + model_output, + lib_output, + rtol=1e-3, + atol=1e-2, + msg="Lightning attention implementations produce different results", + ) + + print("✅ Two implementations match") + + +def get_benchmark(): + batch_size_range = [2**i for i in range(0, 7)] # max 64 + seq_length_range = [256, 512, 1024, 2048, 4096] # max 4096 + configs = list(itertools.product(batch_size_range, seq_length_range)) + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["MiniMax-Text-01", "OpenNLPLab"], + line_names=[ + "MiniMax-Text-01 Model Implementation", + "OpenNLPLab Library Implementation", + ], + styles=[("blue", "-"), ("green", "-")], + ylabel="us", + plot_name="lightning-attention-prefill-performance", + args={}, + ) + ) + def benchmark(batch_size, seq_len, provider): + dtype = torch.bfloat16 + device = torch.device("cuda") + + params = { + "hidden_size": 6144, + "num_attention_heads": 64, + "head_dim": 96, + "hidden_act": "gelu", + } + + hidden_states = torch.randn( + batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device + ) + + attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device) + + slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device) + model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device) + model_attn.eval() + + quantiles = [0.5, 0.2, 0.8] + if provider == "MiniMax-Text-01": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: model_attn.inference( + hidden_states, attn_mask=attention_mask, slope_rate=slope_rate + ), + quantiles=quantiles, + ) + else: + + def run_lib(): + qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, 
dim=-1) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + lib_output = lightning_attn_func(q, k, v, slope_rate) + lib_output = lib_output.transpose(1, 2).contiguous() + lib_output = lib_output.view(batch_size, seq_len, -1) + lib_output = model_attn.norm(lib_output) + lib_output = ( + torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output + ) + return model_attn.out_proj(lib_output) + + ms, min_ms, max_ms = triton.testing.do_bench( + run_lib, + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./configs/benchmark_ops/lightning_attention_prefill/", + help="Path to save lightning attention prefill benchmark results", + ) + args = parser.parse_args() + + # Run correctness test first + # Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json + params = { + "hidden_size": 6144, + "num_attention_heads": 64, + "head_dim": 96, + "hidden_act": "silu", + } + test_lightning_attention_implementations(params) + + # Run performance benchmark + benchmark = get_benchmark() + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmark/kernels/quantization/bench_int8_quant.py b/benchmark/kernels/quantization/bench_int8_quant.py new file mode 100644 index 00000000000..94b795690bf --- /dev/null +++ b/benchmark/kernels/quantization/bench_int8_quant.py @@ -0,0 +1,94 @@ +import argparse + +import torch +import triton +from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant + +from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 + + +@torch.compile(backend="inductor") +def torch_int8_quant(x): + int8_max = torch.iinfo(torch.int8).max + + abs_max = x.abs().max(dim=-1, keepdim=True).values + scales = abs_max.to(torch.float32) / float(int8_max) + + q_x = (x / scales).round().to(torch.int8) + + return q_x, scales + + +def _test_accuracy_once(M, K, input_dtype, device): + x = torch.randn(M, K, dtype=input_dtype, device=device) * 5000 + out, scales, _ = vllm_scaled_int8_quant(x, symmetric=True) + out1, scales1 = per_token_quant_int8(x) + out2, scales2 = torch_int8_quant(x) + torch.testing.assert_close(out, out2, atol=1, rtol=0) + torch.testing.assert_close(out, out1, atol=1, rtol=0) + torch.testing.assert_close(scales, scales2) + torch.testing.assert_close(scales1, scales2) + print(f"M: {M}, K: {K}, type: {input_dtype} OK") + + +def test_accuracy(): + Ms = [1, 13, 128, 1024, 2048, 4096] + Ks = [512, 1024, 2048, 8192] + input_dtypes = [torch.float16, torch.bfloat16] + for M in Ms: + for K in Ks: + for input_dtype in input_dtypes: + _test_accuracy_once(M, K, input_dtype, "cuda") + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048], + x_log=False, + line_arg="provider", + line_vals=["vllm op", "triton", "torch.compile"], + line_names=["vllm op", "triton", "torch.compile"], + styles=[("blue", "-"), ("orange", "-"), ("red", "-")], + ylabel="ms", + plot_name="int8 per token quant", + args={}, + ) +) +def benchmark(batch_size, provider): + M, K = batch_size, 16384 + x = torch.randn(M, K, dtype=torch.float16, device="cuda") * 1000 + + quantiles = [0.5, 0.2, 0.8] + if provider == "vllm op": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: vllm_scaled_int8_quant(x, symmetric=True), + quantiles=quantiles, + ) + if 
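For intuition, per-token symmetric int8 quantization picks one scale per row, `scale = max(|x|) / 127`, and rounds `x / scale` into int8, which is what `torch_int8_quant` above computes. A tiny self-contained numeric example (values are illustrative):

```python
import torch

x = torch.tensor([[-1.0, 0.5, 2.0]])
scale = x.abs().max(dim=-1, keepdim=True).values / torch.iinfo(torch.int8).max  # 2 / 127 ≈ 0.01575
q = (x / scale).round().to(torch.int8)                                          # tensor([[-64,  32, 127]])
x_hat = q.to(torch.float32) * scale                                             # ≈ [[-1.008, 0.504, 2.000]]
```

The accuracy check above compares implementations with `atol=1` on the int8 values, which absorbs small differences in rounding behavior between kernels.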
provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: per_token_quant_int8(x), + quantiles=quantiles, + ) + if provider == "torch.compile": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch_int8_quant(x), + quantiles=quantiles, + ) + + return ms, min_ms, max_ms + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./bench_int8_quant_res", + help="Path to save int8 quant benchmark results", + ) + args = parser.parse_args() + + test_accuracy() + + benchmark.run(print_data=True, show_plots=True, save_path=args.save_path) diff --git a/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py b/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py new file mode 100644 index 00000000000..ad7b180ce1d --- /dev/null +++ b/benchmark/kernels/rmsnorm/benchmark_rmsnorm.py @@ -0,0 +1,231 @@ +import itertools +from typing import Optional, Tuple, Union + +import torch +import triton +import triton.language as tl +from flashinfer.norm import fused_add_rmsnorm, rmsnorm +from torch import nn +from vllm import _custom_ops as vllm_ops + + +class HuggingFaceRMSNorm(nn.Module): + def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + if residual is None: + return x + else: + return x, residual + + +def rmsnorm_naive( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + naive_norm = HuggingFaceRMSNorm(x.shape[-1], eps=eps) + naive_norm.weight = nn.Parameter(weight) + naive_norm = naive_norm.to(x.device) + + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + output = naive_norm(x, residual) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_flashinfer( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + fused_add_rmsnorm(x, residual, weight, eps) + output = (x, residual) + else: + output = rmsnorm(x, weight, eps) + + if isinstance(output, tuple): + output = (output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def rmsnorm_vllm( + x: torch.Tensor, + weight: torch.Tensor, + residual: Optional[torch.Tensor] = None, + eps: float = 1e-6, +): + orig_shape = x.shape + x = x.view(-1, x.shape[-1]) + if residual is not None: + residual = residual.view(-1, residual.shape[-1]) + + if residual is not None: + vllm_ops.fused_add_rms_norm(x, residual, weight, eps) + output = (x, residual) + else: + out = torch.empty_like(x) + vllm_ops.rms_norm(out, x, weight, eps) + output = out + + if isinstance(output, tuple): + output = 
(output[0].view(orig_shape), output[1].view(orig_shape)) + else: + output = output.view(orig_shape) + return output + + +def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): + dtype = torch.bfloat16 + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + output_naive = rmsnorm_naive( + x.clone(), weight, residual.clone() if residual is not None else None + ) + output_flashinfer = rmsnorm_flashinfer( + x.clone(), weight, residual.clone() if residual is not None else None + ) + output_vllm = rmsnorm_vllm( + x.clone(), weight, residual.clone() if residual is not None else None + ) + + if use_residual: + output_naive = output_naive[0] + output_flashinfer = output_flashinfer[0] + output_vllm = output_vllm[0] + + print(f"Naive output={output_naive}") + print(f"FlashInfer output={output_flashinfer}") + print(f"VLLM output={output_vllm}") + + if torch.allclose( + output_naive, output_flashinfer, atol=1e-2, rtol=1e-2 + ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [2**i for i in range(0, 7, 2)] +seq_length_range = [2**i for i in range(6, 11, 1)] +head_num_range = [32, 48] +configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range)) + + +def get_benchmark(use_residual): + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["head_num", "batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["huggingface", "flashinfer", "vllm"], + line_names=["HuggingFace", "FlashInfer", "vLLM"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual", + args={}, + ) + ) + def benchmark(head_num, batch_size, seq_len, provider): + dtype = torch.bfloat16 + hidden_size = head_num * 128 # assuming head_dim = 128 + + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") + weight = torch.ones(hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) if use_residual else None + + quantiles = [0.5, 0.2, 0.8] + + if provider == "huggingface": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_naive( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + elif provider == "flashinfer": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_flashinfer( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: rmsnorm_vllm( + x.clone(), + weight, + residual.clone() if residual is not None else None, + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--use_residual", action="store_true", help="Whether to use residual connection" + ) + parser.add_argument( + "--save_path", + type=str, + default="./configs/benchmark_ops/rmsnorm/", + help="Path to save rmsnorm benchmark results", + ) + args = parser.parse_args() + + # Run correctness test + calculate_diff( + batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual + ) + + # Get the 
benchmark function with proper use_residual setting + benchmark = get_benchmark(args.use_residual) + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmark/kernels/scheduler_batch/benchmark_write_req_to_token_pool_triton.py b/benchmark/kernels/scheduler_batch/benchmark_write_req_to_token_pool_triton.py new file mode 100644 index 00000000000..a9ad7bc5fdc --- /dev/null +++ b/benchmark/kernels/scheduler_batch/benchmark_write_req_to_token_pool_triton.py @@ -0,0 +1,345 @@ +import itertools +import os +from typing import List + +import numpy as np +import pytest +import torch +import triton +import triton.language as tl + + +@triton.jit +def write_req_to_token_pool_triton( + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + req_to_token_ptr_stride: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 512 + pid = tl.program_id(0) + + req_pool_index = tl.load(req_pool_indices + pid) + pre_len = tl.load(pre_lens + pid) + seq_len = tl.load(seq_lens + pid) + + # TODO: optimize this? + cumsum_start = 0 + for i in range(pid): + cumsum_start += tl.load(extend_lens + i) + + num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE) + for i in range(num_loop): + offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = offset < (seq_len - pre_len) + value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask) + tl.store( + req_to_token_ptr + + req_pool_index * req_to_token_ptr_stride + + offset + + pre_len, + value, + mask=mask, + ) + + +@triton.jit +def write_req_to_token_pool_triton_optimize( + req_to_token_ptr, # [max_batch, max_context_len] + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + req_to_token_ptr_stride: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_batch = tl.program_id(0) + pid_token = tl.program_id(1) + + req_pool_index = tl.load(req_pool_indices + pid_batch) + pre_len = tl.load(pre_lens + pid_batch) + seq_len = tl.load(seq_lens + pid_batch) + extend_len = seq_len - pre_len + + cumsum_start = 0 + for i in range(pid_batch): + cumsum_start += tl.load(extend_lens + i) + + token_start = pid_token * BLOCK_SIZE + + offset = tl.arange(0, BLOCK_SIZE) + actual_offset = token_start + offset + mask = actual_offset < extend_len + + src_ptr = out_cache_loc + cumsum_start + actual_offset + src_ptr = tl.max_contiguous(tl.multiple_of(src_ptr, BLOCK_SIZE), BLOCK_SIZE) + value = tl.load(src_ptr, mask=mask) + dst_ptr = ( + req_to_token_ptr + + req_pool_index * req_to_token_ptr_stride + + actual_offset + + pre_len + ) + dst_ptr = tl.max_contiguous(tl.multiple_of(dst_ptr, BLOCK_SIZE), BLOCK_SIZE) + + tl.store(dst_ptr, value, mask=mask) + + +def write_req_to_token_pool_reference( + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + pre_lens: torch.Tensor, + seq_lens: torch.Tensor, + extend_lens: torch.Tensor, + out_cache_loc: torch.Tensor, +) -> None: + """Reference implementation using PyTorch""" + for i in range(len(req_pool_indices)): + req_pool_idx = req_pool_indices[i].item() + pre_len = pre_lens[i].item() + seq_len = seq_lens[i].item() + extend_len = extend_lens[i].item() + + cumsum_start = sum(extend_lens[:i].tolist()) + + # Copy values from out_cache_loc to req_to_token + req_to_token[req_pool_idx, pre_len:seq_len] = out_cache_loc[ + cumsum_start : cumsum_start + extend_len + ] + + +def test_write_req_to_token_pool(): + max_batch = 4097 + max_context_len = 6148 + batch_size = 1 + extend_len = 14 + + # Initialize input tensors + 
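Both Triton kernels and the reference implementation perform the same scatter: for request `i`, the newly allocated slots `out_cache_loc[cumsum(extend_lens[:i]) : cumsum(extend_lens[:i+1])]` are written into row `req_pool_indices[i]` of the token pool at columns `[pre_len, seq_len)`. A minimal self-contained illustration with made-up numbers:

```python
import torch

req_to_token = torch.zeros(4, 8, dtype=torch.int32)
req_pool_indices = [2, 0]              # pool rows owned by two requests
pre_lens, seq_lens = [1, 0], [4, 2]    # so extend_lens = [3, 2]
out_cache_loc = torch.arange(5, dtype=torch.int32)

start = 0
for i, row in enumerate(req_pool_indices):
    n = seq_lens[i] - pre_lens[i]
    req_to_token[row, pre_lens[i]:seq_lens[i]] = out_cache_loc[start:start + n]
    start += n

# req_to_token[2, 1:4] -> tensor([0, 1, 2]); req_to_token[0, 0:2] -> tensor([3, 4])
```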
req_to_token = torch.zeros( + (max_batch, max_context_len), dtype=torch.int32, device="cuda" + ) + req_pool_indices = torch.tensor([42], dtype=torch.int32, device="cuda") + pre_lens = torch.tensor([8], dtype=torch.int32, device="cuda") + seq_lens = torch.tensor([22], dtype=torch.int32, device="cuda") + extend_lens = torch.tensor([extend_len], dtype=torch.int32, device="cuda") + out_cache_loc = torch.arange(extend_len, dtype=torch.int32, device="cuda") + + # Create copies for reference implementation + req_to_token_ref = req_to_token.clone() + req_to_token_opt = req_to_token.clone() + + # Run original triton kernel + write_req_to_token_pool_triton[(batch_size,)]( + req_to_token, + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + max_context_len, + ) + + # Run optimized triton kernel + def grid(batch_size, extend_len): + num_token_blocks = triton.cdiv(extend_len, 512) + return (batch_size, num_token_blocks) + + write_req_to_token_pool_triton_optimize[grid(batch_size, extend_len)]( + req_to_token_opt, + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + max_context_len, + BLOCK_SIZE=512, + ) + + # Run reference implementation + write_req_to_token_pool_reference( + req_to_token_ref, + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + ) + + # Compare results + torch.testing.assert_close(req_to_token, req_to_token_ref) + torch.testing.assert_close(req_to_token_opt, req_to_token_ref) + + # Test case 2: batch size > 1 + batch_size = 3 + extend_lens_list = [14, 20, 30] + total_extend_len = sum(extend_lens_list) + + req_to_token = torch.zeros( + (max_batch, max_context_len), dtype=torch.int32, device="cuda" + ) + req_pool_indices = torch.tensor([42, 100, 200], dtype=torch.int32, device="cuda") + pre_lens = torch.tensor([8, 10, 15], dtype=torch.int32, device="cuda") + seq_lens = torch.tensor([22, 30, 45], dtype=torch.int32, device="cuda") + extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda") + out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda") + + req_to_token_ref = req_to_token.clone() + req_to_token_opt = req_to_token.clone() + + # Run original triton kernel + write_req_to_token_pool_triton[(batch_size,)]( + req_to_token, + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + max_context_len, + ) + + # Run optimized triton kernel + max_extend_len = max(extend_lens_list) + write_req_to_token_pool_triton_optimize[grid(batch_size, max_extend_len)]( + req_to_token_opt, + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + max_context_len, + BLOCK_SIZE=512, + ) + + # Run reference implementation + write_req_to_token_pool_reference( + req_to_token_ref, + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + ) + + # Compare results + torch.testing.assert_close(req_to_token, req_to_token_ref) + torch.testing.assert_close(req_to_token_opt, req_to_token_ref) + + +def get_benchmark(): + batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128] + extend_lens = [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] + configs = list(itertools.product(batch_sizes, extend_lens)) + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "extend_len"], + x_vals=configs, + line_arg="provider", + line_vals=["reference", "triton", "triton_optimize"], + line_names=["PyTorch", "Triton", "Triton Optimized"], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + 
plot_name="write-req-to-token-pool-performance", + args={}, + ) + ) + def benchmark(batch_size, extend_len, provider): + max_batch = 256 + max_context_len = 16384 + + extend_lens_list = [extend_len] * batch_size + total_extend_len = sum(extend_lens_list) + + req_to_token = torch.zeros( + (max_batch, max_context_len), dtype=torch.int32, device="cuda" + ) + req_pool_indices = torch.arange(batch_size, dtype=torch.int32, device="cuda") + pre_lens = torch.ones(batch_size, dtype=torch.int32, device="cuda") * 8 + seq_lens = pre_lens + extend_len + extend_lens = torch.tensor(extend_lens_list, dtype=torch.int32, device="cuda") + out_cache_loc = torch.arange(total_extend_len, dtype=torch.int32, device="cuda") + + quantiles = [0.5, 0.2, 0.8] + + if provider == "reference": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: write_req_to_token_pool_reference( + req_to_token.clone(), + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + ), + quantiles=quantiles, + ) + elif provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: write_req_to_token_pool_triton[(batch_size,)]( + req_to_token.clone(), + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + max_context_len, + ), + quantiles=quantiles, + ) + else: + + def run_optimized(): + block_size = 128 if extend_len <= 1024 else 512 + grid_config = (batch_size, triton.cdiv(extend_len, block_size)) + write_req_to_token_pool_triton_optimize[grid_config]( + req_to_token.clone(), + req_pool_indices, + pre_lens, + seq_lens, + extend_lens, + out_cache_loc, + max_context_len, + BLOCK_SIZE=block_size, + ) + + ms, min_ms, max_ms = triton.testing.do_bench( + run_optimized, quantiles=quantiles + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +def run_benchmark(save_path: str = "./configs/benchmark_ops/write_req_to_token_pool/"): + """Run benchmark and save results""" + + # Ensure save path exists + os.makedirs(save_path, exist_ok=True) + + # Run correctness test + test_write_req_to_token_pool() + print("Correctness test passed!") + + # Run performance test + benchmark = get_benchmark() + benchmark.run(print_data=True, save_path=save_path) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./configs/benchmark_ops/write_req_to_token_pool/", + help="Path to save benchmark results", + ) + args = parser.parse_args() + + run_benchmark(args.save_path) diff --git a/benchmark/tree_of_thought_deep/bench_sglang.py b/benchmark/tree_of_thought_deep/bench_sglang.py index b60f1f00f19..bfb2a4113de 100644 --- a/benchmark/tree_of_thought_deep/bench_sglang.py +++ b/benchmark/tree_of_thought_deep/bench_sglang.py @@ -103,6 +103,7 @@ def tree_search(s, question, num_branches): def main(args): lines = read_jsonl(args.data_path) + lines = list(lines) # Construct prompts num_branches = 2 diff --git a/docker/Dockerfile b/docker/Dockerfile index aa9f3a4e4cf..cec05825d0b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -26,14 +26,29 @@ WORKDIR /sgl-workspace ARG CUDA_VERSION RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ && git clone --depth=1 https://github.com/sgl-project/sglang.git \ + && if [ "$CUDA_VERSION" = "12.1.1" ]; then \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu121; \ + elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \ + elif [ 
"$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \ + elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118; \ + python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ + else \ + echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ + fi \ && cd sglang \ && if [ "$BUILD_TYPE" = "srt" ]; then \ if [ "$CUDA_VERSION" = "12.1.1" ]; then \ python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/; \ elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ + elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu118/torch2.4/flashinfer/; \ + python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ else \ echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ fi; \ @@ -42,8 +57,11 @@ RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/; \ elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ + elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu118/torch2.4/flashinfer/; \ + python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ else \ echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \ fi; \ diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev index d79dabeb5fc..5ff1fa7a51a 100644 --- a/docker/Dockerfile.dev +++ b/docker/Dockerfile.dev @@ -18,9 +18,19 @@ RUN apt-get update && apt-get install -y \ silversearcher-ag \ cloc \ unzip \ + pkg-config \ + libssl-dev \ + bear \ && rm -rf /var/lib/apt/lists/* \ && apt-get clean +RUN apt update -y \ + && apt install -y --no-install-recommends gnupg \ + && echo "deb http://developer.download.nvidia.com/devtools/repos/ubuntu2004/amd64 /" | tee /etc/apt/sources.list.d/nvidia-devtools.list \ + && apt-key adv --fetch-keys http://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub \ + && apt update -y \ + && apt install nsight-systems-cli -y + # Set up locale RUN locale-gen en_US.UTF-8 ENV LANG en_US.UTF-8 @@ -32,7 +42,8 @@ RUN python3 -m pip install --no-cache-dir \ pytest \ black \ isort \ - icdiff + icdiff \ + pre-commit # Install diff-so-fancy RUN curl -LSso /usr/local/bin/diff-so-fancy https://github.com/so-fancy/diff-so-fancy/releases/download/v1.4.4/diff-so-fancy \ diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 2c9af6e7b0d..af9f9e24df7 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -1,8 +1,8 @@ # Usage (to build SGLang ROCm docker image): -# docker build --build-arg SGL_BRANCH=v0.4.0.post1 -t v0.4.0.post1-rocm620 -f Dockerfile.rocm . 
+# docker build --build-arg SGL_BRANCH=v0.4.2.post1 -t v0.4.2.post1-rocm620 -f Dockerfile.rocm . # default base image -ARG BASE_IMAGE="rocm/vllm-dev:20241022" +ARG BASE_IMAGE="rocmshared/vllm-rocm:20250114-tuned-elementwise-layernorm" FROM $BASE_IMAGE AS base USER root @@ -13,6 +13,13 @@ ARG SGL_REPO="https://github.com/sgl-project/sglang" ENV SGL_DEFAULT="main" ARG SGL_BRANCH=${SGL_DEFAULT} +ARG TRITON_REPO="https://github.com/triton-lang/triton.git" +ARG TRITON_COMMIT="845d75a" + + +ARG ATER_REPO="https://github.com/HaiShaw/ater" +ARG CK_COMMITS="fa05ae" + RUN git clone ${SGL_REPO} \ && cd sglang \ && if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \ @@ -30,6 +37,24 @@ RUN git clone ${SGL_REPO} \ RUN cp -r /sgl-workspace/sglang /sglang RUN python -m pip cache purge +RUN pip install IPython \ + && pip install orjson \ + && pip install python-multipart \ + && pip install torchao \ + && pip install pybind11 + +RUN pip uninstall -y triton +RUN git clone ${TRITON_REPO} \ + && cd triton \ + && git checkout ${TRITON_COMMIT} \ + && cd python \ + && python3 setup.py install + +RUN git clone ${ATER_REPO} \ + && cd ater \ + && git submodule update --init --recursive \ + && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop + # Performance environment variable. ENV HIP_FORCE_DEV_KERNARG=1 diff --git a/docs/README.md b/docs/README.md index 67c3ad19411..0a12d64b1f1 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,35 +1,77 @@ # SGLang Documentation -This is the documentation repository for SGLang. It is auto-generated from https://github.com/sgl-project/sglang/tree/main/docs. +We recommend new contributors start from writing documentation, which helps you quickly understand SGLang codebase. Most documentation files are located under the `docs/` folder. We prefer **Jupyter Notebooks** over Markdown so that all examples can be executed and validated by our docs CI pipeline. -## Build the documentation website +## Docs Workflow -### Dependency -``` +### Install Dependency + +```bash pip install -r requirements.txt ``` -### Build -``` +### Update Documentation + +Update your Jupyter notebooks in the appropriate subdirectories under `docs/`. If you add new files, remember to update `index.rst` (or relevant `.rst` files) accordingly. + +- **`pre-commit run --all-files`** manually runs all configured checks, applying fixes if possible. If it fails the first time, re-run it to ensure lint errors are fully resolved. Make sure your code passes all checks **before** creating a Pull Request. +- **Do not commit** directly to the `main` branch. Always create a new branch (e.g., `feature/my-new-feature`), push your changes, and open a PR from that branch. + +```bash +# 1) Compile all Jupyter notebooks +make compile + +# 2) Generate static HTML make html -``` -### Clean -To remove all generated files: -``` -make clean -``` +# 3) Preview documentation locally +# Open your browser at the displayed port to view the docs +bash serve.sh -### Serve (preview) -Run an HTTP server and visit http://localhost:8000 in your browser. +# 4) Clean notebook outputs +# nbstripout removes notebook outputs so your PR stays clean +pip install nbstripout +find . -name '*.ipynb' -exec nbstripout {} \; +# 5) Pre-commit checks and create a PR +# After these checks pass, push your changes and open a PR on your branch +pre-commit run --all-files ``` -python3 -m http.server --d _build/html + + +If you need to run and shut down a SGLang server or engine, following these examples: + +1. 
Launch and close Sever: + +```python +#Launch Sever + +from sglang.utils import ( + execute_shell_command, + wait_for_server, + terminate_process, + print_highlight, +) + +server_process = execute_shell_command( + "python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0" +) + +wait_for_server("http://localhost:30000") + +# Terminate Sever + +terminate_process(server_process) ``` +2. Launch Engine and close Engine -### Deploy -Clone [sgl-project.github.io](https://github.com/sgl-project/sgl-project.github.io) and make sure you have write access. +```python +# Launch Engine -```bash -export DOC_SITE_PATH=../../sgl-project.github.io # update this with your path -python3 deploy.py +import sglang as sgl +import asyncio + +llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct") + +# Terminalte Engine +llm.shutdown() ``` diff --git a/docs/backend/backend.md b/docs/backend/backend.md deleted file mode 100644 index 79d17b809c7..00000000000 --- a/docs/backend/backend.md +++ /dev/null @@ -1,168 +0,0 @@ -# Backend: SGLang Runtime (SRT) -The SGLang Runtime (SRT) is an efficient serving engine. - -## Quick Start -Launch a server -``` -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 -``` - -Send a request -``` -curl http://localhost:30000/generate \ - -H "Content-Type: application/json" \ - -d '{ - "text": "Once upon a time,", - "sampling_params": { - "max_new_tokens": 16, - "temperature": 0 - } - }' -``` - -Learn more about the argument specification, streaming, and multi-modal support [here](../references/sampling_params.md). - -## OpenAI Compatible API -In addition, the server supports OpenAI-compatible APIs. - -```python -import openai -client = openai.Client( - base_url="http://127.0.0.1:30000/v1", api_key="EMPTY") - -# Text completion -response = client.completions.create( - model="default", - prompt="The capital of France is", - temperature=0, - max_tokens=32, -) -print(response) - -# Chat completion -response = client.chat.completions.create( - model="default", - messages=[ - {"role": "system", "content": "You are a helpful AI assistant"}, - {"role": "user", "content": "List 3 countries and their capitals."}, - ], - temperature=0, - max_tokens=64, -) -print(response) - -# Text embedding -response = client.embeddings.create( - model="default", - input="How are you today", -) -print(response) -``` - -It supports streaming, vision, and almost all features of the Chat/Completions/Models/Batch endpoints specified by the [OpenAI API Reference](https://platform.openai.com/docs/api-reference/). - -## Additional Server Arguments -- To enable multi-GPU tensor parallelism, add `--tp 2`. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command. -``` -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 2 -``` -- To enable multi-GPU data parallelism, add `--dp 2`. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total. -``` -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dp 2 --tp 2 -``` -- If you see out-of-memory errors during serving, try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`. 
-``` -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --mem-fraction-static 0.7 -``` -- See [hyperparameter tuning](../references/hyperparameter_tuning.md) on tuning hyperparameters for better performance. -- If you see out-of-memory errors during prefill for long prompts, try to set a smaller chunked prefill size. -``` -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096 -``` -- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently. -- To enable torchao quantization, add `--torchao-config int4wo-128`. It supports other [quantization strategies (INT8/FP8)](https://github.com/sgl-project/sglang/blob/v0.3.6/python/sglang/srt/server_args.py#L671) as well. -- To enable fp8 weight quantization, add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments. -- To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`. -- If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](../references/custom_chat_template.md). - -- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port, you can use the following commands. If you meet deadlock, please try to add `--disable-cuda-graph` -``` -# Node 0 -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 0 - -# Node 1 -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 1 -``` - -## Engine Without HTTP Server - -We also provide an inference engine **without a HTTP server**. For example, - -```python -import sglang as sgl - -def main(): - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - sampling_params = {"temperature": 0.8, "top_p": 0.95} - llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct") - - outputs = llm.generate(prompts, sampling_params) - for prompt, output in zip(prompts, outputs): - print("===============================") - print(f"Prompt: {prompt}\nGenerated text: {output['text']}") - -if __name__ == "__main__": - main() -``` - -This can be used for offline batch inference and building custom servers. -You can view the full example [here](https://github.com/sgl-project/sglang/tree/main/examples/runtime/engine). - -## Use Models From ModelScope -
-More - -To use a model from [ModelScope](https://www.modelscope.cn), set the environment variable SGLANG_USE_MODELSCOPE. -``` -export SGLANG_USE_MODELSCOPE=true -``` -Launch [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instruct) Server -``` -SGLANG_USE_MODELSCOPE=true python -m sglang.launch_server --model-path qwen/Qwen2-7B-Instruct --port 30000 -``` - -Or start it by docker. -```bash -docker run --gpus all \ - -p 30000:30000 \ - -v ~/.cache/modelscope:/root/.cache/modelscope \ - --env "SGLANG_USE_MODELSCOPE=true" \ - --ipc=host \ - lmsysorg/sglang:latest \ - python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 30000 -``` - -
-</details>
-
-## Example: Run Llama 3.1 405B
-<details>
-More - -```bash -# Run 405B (fp8) on a single node -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8 - -# Run 405B (fp16) on two nodes -## on the first node, replace the `172.16.4.52:20000` with your own first node ip address and port -GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0 --disable-cuda-graph - -## on the first node, replace the `172.16.4.52:20000` with your own first node ip address and port -GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1 --disable-cuda-graph -``` - -
diff --git a/docs/backend/function_calling.ipynb b/docs/backend/function_calling.ipynb new file mode 100644 index 00000000000..05e7108e60e --- /dev/null +++ b/docs/backend/function_calling.ipynb @@ -0,0 +1,523 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tool and Function Calling\n", + "\n", + "This guide demonstrates how to use SGLang’s **Tool Calling** functionality." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## OpenAI Compatible API" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Launching the Server" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from openai import OpenAI\n", + "import json\n", + "from sglang.utils import (\n", + " execute_shell_command,\n", + " wait_for_server,\n", + " terminate_process,\n", + " print_highlight,\n", + ")\n", + "\n", + "server_process = execute_shell_command(\n", + " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --tool-call-parser llama3 --port 30333 --host 0.0.0.0\" # llama3\n", + ")\n", + "wait_for_server(\"http://localhost:30333\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that `--tool-call-parser` defines the parser used to interpret responses. Currently supported parsers include:\n", + "\n", + "- llama3: Llama 3.1 / 3.2 (e.g. meta-llama/Llama-3.1-8B-Instruct, meta-llama/Llama-3.2-1B-Instruct).\n", + "- mistral: Mistral (e.g. mistralai/Mistral-7B-Instruct-v0.3, mistralai/Mistral-Nemo-Instruct-2407, mistralai/\n", + "Mistral-Nemo-Instruct-2407, mistralai/Mistral-7B-v0.3).\n", + "- qwen25: Qwen 2.5 (e.g. Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen2.5-7B-Instruct)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define Tools for Function Call\n", + "Below is a Python snippet that shows how to define a tool as a dictionary. The dictionary includes a tool name, a description, and property defined Parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Define tools\n", + "tools = [\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_current_weather\",\n", + " \"description\": \"Get the current weather in a given location\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"city\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city to find the weather for, e.g. 'San Francisco'\",\n", + " },\n", + " \"state\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"the two-letter abbreviation for the state that the city is\"\n", + " \" in, e.g. 'CA' which would mean 'California'\",\n", + " },\n", + " \"unit\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The unit to fetch the temperature in\",\n", + " \"enum\": [\"celsius\", \"fahrenheit\"],\n", + " },\n", + " },\n", + " \"required\": [\"city\", \"state\", \"unit\"],\n", + " },\n", + " },\n", + " }\n", + "]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define Messages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_messages():\n", + " return [\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"What's the weather like in Boston today? 
Please respond with the format: Today's weather is :{function call result}\",\n", + " }\n", + " ]\n", + "\n", + "\n", + "messages = get_messages()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Initialize the Client" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize OpenAI-like client\n", + "client = OpenAI(api_key=\"None\", base_url=\"http://0.0.0.0:30333/v1\")\n", + "model_name = client.models.list().data[0].id" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Non-Streaming Request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Non-streaming mode test\n", + "response_non_stream = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.8,\n", + " top_p=0.8,\n", + " stream=False, # Non-streaming\n", + " tools=tools,\n", + ")\n", + "print_highlight(\"Non-stream response:\")\n", + "print(response_non_stream)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Streaming Request" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Streaming mode test\n", + "print_highlight(\"Streaming response:\")\n", + "response_stream = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.8,\n", + " top_p=0.8,\n", + " stream=True, # Enable streaming\n", + " tools=tools,\n", + ")\n", + "\n", + "chunks = []\n", + "for chunk in response_stream:\n", + " chunks.append(chunk)\n", + " if chunk.choices[0].delta.tool_calls:\n", + " print(chunk.choices[0].delta.tool_calls[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### Handle Tool Calls\n", + "\n", + "When the engine determines it should call a particular tool, it will return arguments or partial arguments through the response. You can parse these arguments and later invoke the tool accordingly." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Non-Streaming Request**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "name_non_stream = response_non_stream.choices[0].message.tool_calls[0].function.name\n", + "arguments_non_stream = (\n", + " response_non_stream.choices[0].message.tool_calls[0].function.arguments\n", + ")\n", + "\n", + "print_highlight(f\"Final streamed function call name: {name_non_stream}\")\n", + "print_highlight(f\"Final streamed function call arguments: {arguments_non_stream}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Streaming Request**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Parse and combine function call arguments\n", + "arguments = []\n", + "for chunk in chunks:\n", + " choice = chunk.choices[0]\n", + " delta = choice.delta\n", + " if delta.tool_calls:\n", + " tool_call = delta.tool_calls[0]\n", + " if tool_call.function.name:\n", + " print_highlight(f\"Streamed function call name: {tool_call.function.name}\")\n", + "\n", + " if tool_call.function.arguments:\n", + " arguments.append(tool_call.function.arguments)\n", + " print(f\"Streamed function call arguments: {tool_call.function.arguments}\")\n", + "\n", + "# Combine all fragments into a single JSON string\n", + "full_arguments = \"\".join(arguments)\n", + "print_highlight(f\"Final streamed function call arguments: {full_arguments}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define a Tool Function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# This is a demonstration, define real function according to your usage.\n", + "def get_current_weather(city: str, state: str, unit: \"str\"):\n", + " return (\n", + " f\"The weather in {city}, {state} is 85 degrees {unit}. 
It is \"\n", + " \"partly cloudly, with highs in the 90's.\"\n", + " )\n", + "\n", + "\n", + "available_tools = {\"get_current_weather\": get_current_weather}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Execute the Tool" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "call_data = json.loads(full_arguments)\n", + "\n", + "messages.append(\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"\",\n", + " \"tool_calls\": {\"name\": \"get_current_weather\", \"arguments\": full_arguments},\n", + " }\n", + ")\n", + "\n", + "# Call the corresponding tool function\n", + "tool_name = messages[-1][\"tool_calls\"][\"name\"]\n", + "tool_to_call = available_tools[tool_name]\n", + "result = tool_to_call(**call_data)\n", + "print_highlight(f\"Function call result: {result}\")\n", + "messages.append({\"role\": \"tool\", \"content\": result, \"name\": tool_name})\n", + "\n", + "print_highlight(f\"Updated message history: {messages}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Send Results Back to Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "final_response = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.8,\n", + " top_p=0.8,\n", + " stream=False,\n", + " tools=tools,\n", + ")\n", + "print_highlight(\"Non-stream response:\")\n", + "print(final_response)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Native API and SGLang Runtime (SRT)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "import requests\n", + "\n", + "# generate an answer\n", + "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n", + "\n", + "messages = get_messages()\n", + "\n", + "input = tokenizer.apply_chat_template(\n", + " messages,\n", + " tokenize=False,\n", + " add_generation_prompt=True,\n", + " tools=tools,\n", + ")\n", + "\n", + "gen_url = \"http://localhost:30333/generate\"\n", + "gen_data = {\"text\": input, \"sampling_params\": {\"skip_special_tokens\": False}}\n", + "gen_response = requests.post(gen_url, json=gen_data).json()[\"text\"]\n", + "print(gen_response)\n", + "\n", + "# parse the response\n", + "parse_url = \"http://localhost:30333/function_call\"\n", + "\n", + "function_call_input = {\n", + " \"text\": gen_response,\n", + " \"tool_call_parser\": \"llama3\",\n", + " \"tools\": tools,\n", + "}\n", + "\n", + "function_call_response = requests.post(parse_url, json=function_call_input)\n", + "function_call_response_json = function_call_response.json()\n", + "print(\"function name: \", function_call_response_json[\"calls\"][0][\"name\"])\n", + "print(\"function arguments: \", function_call_response_json[\"calls\"][0][\"parameters\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Offline Engine API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sglang as sgl\n", + "from sglang.srt.function_call_parser import FunctionCallParser\n", + "from sglang.srt.managers.io_struct import Tool, Function\n", + "\n", + "llm = 
sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")\n", + "tokenizer = llm.tokenizer_manager.tokenizer\n", + "input_ids = tokenizer.apply_chat_template(\n", + " messages, tokenize=True, add_generation_prompt=True, tools=tools\n", + ")\n", + "\n", + "sampling_params = {\n", + " \"max_new_tokens\": 128,\n", + " \"temperature\": 0.3,\n", + " \"top_p\": 0.95,\n", + " \"skip_special_tokens\": False,\n", + "}\n", + "\n", + "# 1) Offline generation\n", + "result = llm.generate(input_ids=input_ids, sampling_params=sampling_params)\n", + "generated_text = result[\"text\"] # Assume there is only one prompt\n", + "\n", + "print(\"=== Offline Engine Output Text ===\")\n", + "print(generated_text)\n", + "\n", + "\n", + "# 2) Parse using FunctionCallParser\n", + "def convert_dict_to_tool(tool_dict: dict) -> Tool:\n", + " function_dict = tool_dict.get(\"function\", {})\n", + " return Tool(\n", + " type=tool_dict.get(\"type\", \"function\"),\n", + " function=Function(\n", + " name=function_dict.get(\"name\"),\n", + " description=function_dict.get(\"description\"),\n", + " parameters=function_dict.get(\"parameters\"),\n", + " ),\n", + " )\n", + "\n", + "\n", + "tools = [convert_dict_to_tool(raw_tool) for raw_tool in tools]\n", + "\n", + "parser = FunctionCallParser(tools=tools, tool_call_parser=\"llama3\")\n", + "normal_text, calls = parser.parse_non_stream(generated_text)\n", + "\n", + "print(\"\\n=== Parsing Result ===\")\n", + "print(\"Normal text portion:\", normal_text)\n", + "print(\"Function call portion:\")\n", + "for call in calls:\n", + " # call: ToolCallItem\n", + " print(f\" - tool name: {call.name}\")\n", + " print(f\" parameters: {call.parameters}\")\n", + "\n", + "# 3) If needed, perform additional logic on the parsed functions, such as automatically calling the corresponding function to obtain a return value, etc." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "llm.shutdown()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## How to support a new model?\n", + "1. Update the TOOLS_TAG_LIST in sglang/srt/function_call_parser.py with the model’s tool tags. Currently supported tags include:\n", + "```\n", + "\tTOOLS_TAG_LIST = [\n", + "\t “<|plugin|>“,\n", + "\t ““,\n", + "\t “<|python_tag|>“,\n", + "\t “[TOOL_CALLS]”\n", + "\t]\n", + "```\n", + "2. Create a new detector class in sglang/srt/function_call_parser.py that inherits from BaseFormatDetector. The detector should handle the model’s specific function call format. For example:\n", + "```\n", + " class NewModelDetector(BaseFormatDetector):\n", + "```\n", + "3. Add the new detector to the MultiFormatParser class that manages all the format detectors." + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb index 26758f7f975..f6c10d745c5 100644 --- a/docs/backend/native_api.ipynb +++ b/docs/backend/native_api.ipynb @@ -348,6 +348,76 @@ "source": [ "terminate_process(reward_process)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Skip Tokenizer and Detokenizer\n", + "\n", + "SGLang Runtime also supports skip tokenizer and detokenizer. 
This is useful in cases like integrating with RLHF workflow." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer_free_server_process = execute_shell_command(\n", + " \"\"\"\n", + "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010 --skip-tokenizer-init\n", + "\"\"\"\n", + ")\n", + "\n", + "wait_for_server(\"http://localhost:30010\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.2-1B-Instruct\")\n", + "\n", + "input_text = \"What is the capital of France?\"\n", + "\n", + "input_tokens = tokenizer.encode(input_text)\n", + "print_highlight(f\"Input Text: {input_text}\")\n", + "print_highlight(f\"Tokenized Input: {input_tokens}\")\n", + "\n", + "response = requests.post(\n", + " \"http://localhost:30010/generate\",\n", + " json={\n", + " \"input_ids\": input_tokens,\n", + " \"sampling_params\": {\n", + " \"temperature\": 0,\n", + " \"max_new_tokens\": 256,\n", + " \"stop_token_ids\": [tokenizer.eos_token_id],\n", + " },\n", + " \"stream\": False,\n", + " },\n", + ")\n", + "output = response.json()\n", + "output_tokens = output[\"token_ids\"]\n", + "\n", + "output_text = tokenizer.decode(output_tokens, skip_special_tokens=False)\n", + "print_highlight(f\"Tokenized Output: {output_tokens}\")\n", + "print_highlight(f\"Decoded Output: {output_text}\")\n", + "print_highlight(f\"Output Text: {output['meta_info']['finish_reason']}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(tokenizer_free_server_process)" + ] } ], "metadata": { diff --git a/docs/backend/offline_engine_api.ipynb b/docs/backend/offline_engine_api.ipynb index 7ce89d435d5..58d24ac3ff6 100644 --- a/docs/backend/offline_engine_api.ipynb +++ b/docs/backend/offline_engine_api.ipynb @@ -37,7 +37,7 @@ "outputs": [], "source": [ "# launch the offline engine\n", - "\n", + "from sglang.utils import stream_and_merge, async_stream_and_merge\n", "import sglang as sgl\n", "import asyncio\n", "\n", @@ -86,20 +86,22 @@ "outputs": [], "source": [ "prompts = [\n", - " \"Hello, my name is\",\n", - " \"The capital of France is\",\n", - " \"The future of AI is\",\n", + " \"Write a short, neutral self-introduction for a fictional character. Hello, my name is\",\n", + " \"Provide a concise factual statement about France’s capital city. The capital of France is\",\n", + " \"Explain possible future trends in artificial intelligence. 
The future of AI is\",\n", "]\n", - "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n", "\n", - "print(\"\\n=== Testing synchronous streaming generation ===\")\n", + "sampling_params = {\n", + " \"temperature\": 0.2,\n", + " \"top_p\": 0.9,\n", + "}\n", "\n", - "for prompt in prompts:\n", - " print(f\"\\nPrompt: {prompt}\")\n", - " print(\"Generated text: \", end=\"\", flush=True)\n", + "print(\"\\n=== Testing synchronous streaming generation with overlap removal ===\\n\")\n", "\n", - " for chunk in llm.generate(prompt, sampling_params, stream=True):\n", - " print(chunk[\"text\"], end=\"\", flush=True)\n", + "for prompt in prompts:\n", + " print(f\"Prompt: {prompt}\")\n", + " merged_output = stream_and_merge(llm, prompt, sampling_params)\n", + " print(\"Generated text:\", merged_output)\n", " print()" ] }, @@ -117,9 +119,9 @@ "outputs": [], "source": [ "prompts = [\n", - " \"Hello, my name is\",\n", - " \"The capital of France is\",\n", - " \"The future of AI is\",\n", + " \"Write a short, neutral self-introduction for a fictional character. Hello, my name is\",\n", + " \"Provide a concise factual statement about France’s capital city. The capital of France is\",\n", + " \"Explain possible future trends in artificial intelligence. The future of AI is\",\n", "]\n", "\n", "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n", @@ -152,13 +154,14 @@ "outputs": [], "source": [ "prompts = [\n", - " \"Hello, my name is\",\n", - " \"The capital of France is\",\n", - " \"The future of AI is\",\n", + " \"Write a short, neutral self-introduction for a fictional character. Hello, my name is\",\n", + " \"Provide a concise factual statement about France’s capital city. The capital of France is\",\n", + " \"Explain possible future trends in artificial intelligence. The future of AI is\",\n", "]\n", + "\n", "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n", "\n", - "print(\"\\n=== Testing asynchronous streaming generation ===\")\n", + "print(\"\\n=== Testing asynchronous streaming generation (no repeats) ===\")\n", "\n", "\n", "async def main():\n", @@ -166,10 +169,11 @@ " print(f\"\\nPrompt: {prompt}\")\n", " print(\"Generated text: \", end=\"\", flush=True)\n", "\n", - " generator = await llm.async_generate(prompt, sampling_params, stream=True)\n", - " async for chunk in generator:\n", - " print(chunk[\"text\"], end=\"\", flush=True)\n", - " print()\n", + " # Replace direct calls to async_generate with our custom overlap-aware version\n", + " async for cleaned_chunk in async_stream_and_merge(llm, prompt, sampling_params):\n", + " print(cleaned_chunk, end=\"\", flush=True)\n", + "\n", + " print() # New line after each prompt\n", "\n", "\n", "asyncio.run(main())" diff --git a/docs/backend/openai_api_completions.ipynb b/docs/backend/openai_api_completions.ipynb index 067a046885d..58b524108db 100644 --- a/docs/backend/openai_api_completions.ipynb +++ b/docs/backend/openai_api_completions.ipynb @@ -24,14 +24,7 @@ "source": [ "## Launch A Server\n", "\n", - "This code block is equivalent to executing \n", - "\n", - "```bash\n", - "python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", - "--port 30000 --host 0.0.0.0\n", - "```\n", - "\n", - "in your terminal and wait for the server to be ready." + "Launch the server in your terminal and wait for it to initialize." 
] }, { @@ -48,10 +41,10 @@ ")\n", "\n", "server_process = execute_shell_command(\n", - " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0\"\n", + " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30020 --host 0.0.0.0\"\n", ")\n", "\n", - "wait_for_server(\"http://localhost:30000\")" + "wait_for_server(\"http://localhost:30020\")" ] }, { @@ -75,7 +68,7 @@ "source": [ "import openai\n", "\n", - "client = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", + "client = openai.Client(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "\n", "response = client.chat.completions.create(\n", " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", @@ -220,74 +213,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Structured decoding (JSON, Regex)\n", - "You can specify a JSON schema or a regular expression to constrain the model output. The model output will be guaranteed to follow the given constraints.\n", + "## Structured Outputs (JSON, Regex, EBNF)\n", "\n", - "### JSON" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "\n", - "json_schema = json.dumps(\n", - " {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"name\": {\"type\": \"string\", \"pattern\": \"^[\\\\w]+$\"},\n", - " \"population\": {\"type\": \"integer\"},\n", - " },\n", - " \"required\": [\"name\", \"population\"],\n", - " }\n", - ")\n", - "\n", - "response = client.chat.completions.create(\n", - " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", - " messages=[\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": \"Give me the information of the capital of France in the JSON format.\",\n", - " },\n", - " ],\n", - " temperature=0,\n", - " max_tokens=128,\n", - " response_format={\n", - " \"type\": \"json_schema\",\n", - " \"json_schema\": {\"name\": \"foo\", \"schema\": json.loads(json_schema)},\n", - " },\n", - ")\n", - "\n", - "print_highlight(response.choices[0].message.content)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Regular expression" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "response = client.chat.completions.create(\n", - " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", - " messages=[\n", - " {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n", - " ],\n", - " temperature=0,\n", - " max_tokens=128,\n", - " extra_body={\"regex\": \"(Paris|London)\"},\n", - ")\n", - "\n", - "print_highlight(response.choices[0].message.content)" + "For OpenAI compatible structed outputs API, refer to [Structured Outputs](https://docs.sglang.ai/backend/structured_outputs.html#OpenAI-Compatible-API) for more details.\n" ] }, { @@ -317,7 +245,7 @@ "import time\n", "from openai import OpenAI\n", "\n", - "client = OpenAI(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", + "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "\n", "requests = [\n", " {\n", @@ -420,7 +348,7 @@ "import time\n", "from openai import OpenAI\n", "\n", - "client = OpenAI(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", + "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "\n", "requests = []\n", "for i in range(100):\n", @@ -497,7 +425,7 @@ "from openai import OpenAI\n", "import os\n", "\n", - "client 
= OpenAI(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", + "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "\n", "requests = []\n", "for i in range(500):\n", diff --git a/docs/backend/openai_api_embeddings.ipynb b/docs/backend/openai_api_embeddings.ipynb index 65b07c384d7..67ce68bcf1b 100644 --- a/docs/backend/openai_api_embeddings.ipynb +++ b/docs/backend/openai_api_embeddings.ipynb @@ -20,14 +20,7 @@ "source": [ "## Launch A Server\n", "\n", - "The following code is equivalent to running this in the shell:\n", - "\n", - "```bash\n", - "python -m sglang.launch_server --model-path Alibaba-NLP/gte-Qwen2-7B-instruct \\\n", - " --port 30000 --host 0.0.0.0 --is-embedding\n", - "```\n", - "\n", - "Remember to add `--is-embedding` to the command." + "Launch the server in your terminal and wait for it to initialize. Remember to add `--is-embedding` to the command." ] }, { diff --git a/docs/backend/openai_api_vision.ipynb b/docs/backend/openai_api_vision.ipynb index af17b440969..da8864c24c9 100644 --- a/docs/backend/openai_api_vision.ipynb +++ b/docs/backend/openai_api_vision.ipynb @@ -22,13 +22,7 @@ "source": [ "## Launch A Server\n", "\n", - "This code block is equivalent to executing \n", - "\n", - "```bash\n", - "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-11B-Vision-Instruct \\\n", - " --port 30000 --chat-template llama_3_vision\n", - "```\n", - "in your terminal and wait for the server to be ready.\n", + "Launch the server in your terminal and wait for it to initialize.\n", "\n", "Remember to add `--chat-template llama_3_vision` to specify the vision chat template, otherwise the server only supports text.\n", "We need to specify `--chat-template` for vision language models because the chat template provided in Hugging Face tokenizer only supports text." diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md new file mode 100644 index 00000000000..7e8f4ca0a54 --- /dev/null +++ b/docs/backend/server_arguments.md @@ -0,0 +1,184 @@ +# Server Arguments + +## Common launch commands + +- To enable multi-GPU tensor parallelism, add `--tp 2`. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command. +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 2 +``` +- To enable multi-GPU data parallelism, add `--dp 2`. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total. We recommend [SGLang Router](https://docs.sglang.ai/router/router.html) for data parallelism. +``` +python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dp 2 --tp 2 +``` + +- If you see out-of-memory errors during serving, try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`. +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --mem-fraction-static 0.7 +``` +- See [hyperparameter tuning](../references/hyperparameter_tuning.md) on tuning hyperparameters for better performance. +- If you see out-of-memory errors during prefill for long prompts, try to set a smaller chunked prefill size. +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --chunked-prefill-size 4096 +``` +- To enable torch.compile acceleration, add `--enable-torch-compile`. 
It accelerates small models on small batch sizes. This does not work for FP8 currently.
+- To enable torchao quantization, add `--torchao-config int4wo-128`. It supports other [quantization strategies (INT8/FP8)](https://github.com/sgl-project/sglang/blob/v0.3.6/python/sglang/srt/server_args.py#L671) as well.
+- To enable fp8 weight quantization, add `--quantization fp8` on an fp16 checkpoint, or directly load an fp8 checkpoint without specifying any arguments.
+- To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
+- If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](../references/custom_chat_template.md).
+
+- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port; you can then use the following commands. If you encounter a deadlock, please try to add `--disable-cuda-graph`.
+```
+# Node 0
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 0
+
+# Node 1
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 1
+```
+
+Please consult the documentation below to learn more about the parameters you may provide when launching a server.
+
+
+## Model and tokenizer
+
+* `model_path`: Path to the model that will be served.
+* `tokenizer_path`: Defaults to the `model_path`.
+* `tokenizer_mode`: By default `auto`; see [here](https://huggingface.co/docs/transformers/en/main_classes/tokenizer) for the different modes.
+* `load_format`: The format the weights are loaded in. Defaults to `*.safetensors`/`*.bin`.
+* `trust_remote_code`: If `True`, will use locally cached config files; otherwise use the remote configs from Hugging Face.
+* `dtype`: Dtype used for the model, defaults to `bfloat16`.
+* `kv_cache_dtype`: Dtype of the kv cache, defaults to the `dtype`.
+* `context_length`: The number of tokens our model can process, *including the input*. Note that extending the default might lead to strange behavior.
+* `device`: The device the model is placed on, defaults to `cuda`.
+* `chat_template`: The chat template to use. Deviating from the default might lead to unexpected responses. For multi-modal chat templates, refer to [here](https://docs.sglang.ai/backend/openai_api_vision.html#Chat-Template).
+* `is_embedding`: Set to true to perform [embedding](https://docs.sglang.ai/backend/openai_api_embeddings.html) / [encode](https://docs.sglang.ai/backend/native_api.html#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api.html#Classify-(reward-model)) tasks.
+* `revision`: Adjust if a specific version of the model should be used.
+* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF.
+* `json_model_override_args`: Override the model config with the provided JSON.
+* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
+
+## Serving: HTTP & API
+
+### HTTP Server configuration
+
+* `port` and `host`: Set up the host and port for the HTTP server. By default `host: str = "127.0.0.1"` and `port: int = 30000`.
+
+### API configuration
+
+* `api_key`: Sets an API key for the server and the OpenAI-compatible API.
+* `file_storage_pth`: Directory for storing uploaded or generated files from API calls.
+* `enable_cache_report`: If set, includes detailed usage of cached tokens in the response usage.
+
+## Parallelism
+
+### Tensor parallelism
+
+* `tp_size`: The number of GPUs the model weights get sharded over. Mainly for saving memory rather than for high throughput; see [this blogpost](https://pytorch.org/tutorials/intermediate/TP_tutorial.html#how-tensor-parallel-works).
+
+### Data parallelism
+
+* `dp_size`: Will be deprecated. The number of data-parallel copies of the model. The [SGLang router](https://docs.sglang.ai/router/router.html) is recommended instead of the current naive data parallelism.
+* `load_balance_method`: Will be deprecated. Load balancing strategy for data parallel requests.
+
+### Expert parallelism
+
+* `ep_size`: Distribute the experts onto multiple GPUs for MoE models. Remember to shard the model weights with `tp_size=ep_size`; for detailed benchmarking refer to [this PR](https://github.com/sgl-project/sglang/pull/2203).
+
+## Memory and scheduling
+
+* `mem_fraction_static`: Fraction of the free GPU memory used for static memory like model weights and KV cache. If building the KV cache fails, it should be increased. If CUDA runs out of memory, it should be decreased.
+* `max_running_requests`: The maximum number of requests to run concurrently.
+* `max_total_tokens`: The maximum number of tokens that can be stored in the KV cache. Used mainly for debugging.
+* `chunked_prefill_size`: Perform the prefill in chunks of this size. A larger chunk size speeds up the prefill phase but increases VRAM consumption. If CUDA runs out of memory, it should be decreased.
+* `max_prefill_tokens`: Token budget for how many tokens to accept in one prefill batch. The actual number is the max of this parameter and the `context_length`.
+* `schedule_policy`: The scheduling policy that controls the processing order of waiting prefill requests in a single engine.
+* `schedule_conservativeness`: Can be used to decrease/increase the conservativeness of the server when taking new requests. Highly conservative behavior leads to starvation, while low conservativeness leads to slowed-down performance.
+* `cpu_offload_gb`: Reserve this amount of RAM in GB for offloading model parameters to the CPU.
+* `prefill_only_one_req`: When this flag is turned on, the engine prefills only one request at a time.
+
+## Other runtime options
+
+* `stream_interval`: Interval (in tokens) for streaming responses. Smaller values lead to smoother streaming, and larger values lead to better throughput.
+* `random_seed`: Can be used to enforce more deterministic behavior.
+* `watchdog_timeout`: Adjusts the watchdog thread's timeout before killing the server if batch generation takes too long.
+* `download_dir`: Use to override the default Hugging Face cache directory for model weights.
+* `base_gpu_id`: Use to adjust the first GPU used to distribute the model across the available GPUs.
+* `allow_auto_truncate`: Automatically truncate requests that exceed the maximum input length.
+
+## Logging
+
+* `log_level`: Global log verbosity.
+* `log_level_http`: Separate verbosity level for the HTTP server logs (if unset, defaults to `log_level`).
+* `log_requests`: Logs the inputs and outputs of all requests for debugging.
+* `show_time_cost`: Prints or logs detailed timing info for internal operations (helpful for performance tuning).
+* `enable_metrics`: Exports Prometheus-like metrics for request usage and performance.
+* `decode_log_interval`: How often (in tokens) to log decode progress.
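+
+As a quick illustration of how several of the options documented above can be combined, the sketch below launches a server with a handful of them and waits for it to come up, reusing the `execute_shell_command`/`wait_for_server` helpers from `sglang.utils` that appear elsewhere in these docs. The specific flag values are illustrative placeholders, not tuned recommendations.
+
+```python
+# Sketch: launch a server with a few of the options described above.
+# The flag values below are assumptions for illustration, not recommendations.
+from sglang.utils import execute_shell_command, wait_for_server, terminate_process
+
+server_process = execute_shell_command(
+    "python -m sglang.launch_server "
+    "--model-path meta-llama/Meta-Llama-3-8B-Instruct "
+    "--mem-fraction-static 0.8 "    # smaller KV cache pool if memory is tight
+    "--max-running-requests 32 "    # cap on concurrently running requests
+    "--log-level info --enable-metrics "
+    "--port 30000 --host 0.0.0.0"
+)
+wait_for_server("http://localhost:30000")
+
+# ... send requests against http://localhost:30000 ...
+
+terminate_process(server_process)
+```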
+
+## Multi-node distributed serving
+
+* `dist_init_addr`: The TCP address used for initializing PyTorch's distributed backend (e.g. `192.168.0.2:25000`).
+* `nnodes`: Total number of nodes in the cluster. Refer to how to run the [Llama 405B model](https://docs.sglang.ai/references/llama_405B.html#run-405b-fp16-on-two-nodes).
+* `node_rank`: Rank (ID) of this node among the `nnodes` in the distributed setup.
+
+
+## LoRA
+
+* `lora_paths`: You may provide a list of adapters to your model. Each batch element will get a model response with the corresponding LoRA adapter applied. Currently `cuda_graph` and `radix_attention` are not supported with this option, so you need to disable them manually. We are still working through these [issues](https://github.com/sgl-project/sglang/issues/2929).
+* `max_loras_per_batch`: Maximum number of LoRAs in a running batch, including the base model.
+
+## Kernel backend
+
+* `attention_backend`: The backend for attention computation and KV cache management.
+* `sampling_backend`: The backend for sampling.
+
+## Constrained Decoding
+
+* `grammar_backend`: The grammar backend for constrained decoding. Detailed usage can be found in this [document](https://docs.sglang.ai/backend/structured_outputs.html).
+* `constrained_json_whitespace_pattern`: Use with the `Outlines` grammar backend to allow JSON with syntactic newlines, tabs or multiple spaces. Details can be found [here](https://dottxt-ai.github.io/outlines/latest/reference/generation/json/#using-pydantic).
+
+## Speculative decoding
+
+* `speculative_draft_model_path`: The draft model path for speculative decoding.
+* `speculative_algorithm`: The algorithm for speculative decoding. Currently only [Eagle](https://arxiv.org/html/2406.16858v1) is supported. Note that the radix cache, chunked prefill, and overlap scheduler are disabled when using EAGLE speculative decoding.
+* `speculative_num_steps`: How many draft passes we run before verifying.
+* `speculative_num_draft_tokens`: The number of tokens proposed in a draft.
+* `speculative_eagle_topk`: The number of top candidates we keep for verification at each step for [Eagle](https://arxiv.org/html/2406.16858v1).
+
+
+## Double Sparsity
+
+* `enable_double_sparsity`: Enables [double sparsity](https://arxiv.org/html/2408.07092v2), which increases throughput.
+* `ds_channel_config_path`: The double sparsity config. For a guide on how to generate the config for your model, see [this repo](https://github.com/andy-yang-1/DoubleSparse/tree/main/config).
+* `ds_heavy_channel_num`: Number of channel indices to keep for each layer.
+* `ds_heavy_token_num`: Number of tokens used for attention during decode. Skip sparse decoding if `min_seq_len` in the batch is smaller than this number.
+* `ds_heavy_channel_type`: The type of heavy channels. Either `q`, `k` or `qk`.
+* `ds_sparse_decode_threshold`: Don't apply sparse decoding if `max_seq_len` in the batch is smaller than this threshold.
+
+## Debug options
+
+*Note: We recommend staying with the defaults for the best possible performance and using these options only for debugging.*
+
+* `disable_radix_cache`: Disable the [Radix](https://lmsys.org/blog/2024-01-17-sglang/) backend for prefix caching.
+* `disable_jump_forward`: Disable [jump-forward](https://lmsys.org/blog/2024-02-05-compressed-fsm/#our-method-jump-forward-decoding-with-a-compressed-finite-state-machine) for the outlines grammar backend.
+* `disable_cuda_graph`: Disable [cuda graph](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/) for the model forward pass.
+* `disable_cuda_graph_padding`: Disable cuda graph when padding is needed. In other case still use cuda graph. +* `disable_outlines_disk_cache`: Disable disk cache for outlines grammar backend. +* `disable_custom_all_reduce`: Disable usage of custom all reduce kernel. +* `disable_mla`: Disable [Multi-Head Latent Attention](https://arxiv.org/html/2405.04434v5) for Deepseek model. +* `disable_overlap_schedule`: Disable the [Overhead-Scheduler](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#zero-overhead-batch-scheduler). +* `enable_nan_detection`: Turning this on makes the sampler print a warning if the logits contain `NaN`. +* `enable_p2p_check`: Turns off the default of allowing always p2p check when accessing GPU. +* `triton_attention_reduce_in_fp32`: In triton kernels this will cast the intermediate attention result to `float32`. + +## Optimization + +*Note: Some of these options are still in experimental stage.* + +* `enable_mixed_chunk`: Enables mixing prefill and decode, see [this discussion](https://github.com/sgl-project/sglang/discussions/1163). +* `enable_dp_attention`: Enable [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models) for Deepseek models. Note that you need to choose `dp_size = tp_size` for this. +* `enable_ep_moe`: Enables expert parallelism, see the description of `ep_size`. +* `enable_torch_compile`: Torch compile the model. This is an experimental feature. +* `torch_compile_max_bs`: The maximum batch size when using `torch_compile`. +* `cuda_graph_max_bs`: Adjust the maximum batchsize when using cuda graph. By default this is chosen for you based on GPU specifics. +* `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you. +* `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row. +* `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8. diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb new file mode 100644 index 00000000000..273d943d120 --- /dev/null +++ b/docs/backend/speculative_decoding.ipynb @@ -0,0 +1,182 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Speculative Decoding\n", + "\n", + "SGLang now provides an EAGLE-based speculative decoding option. The implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.\n", + "\n", + "**Note:** Currently, Speculative Decoding in SGLang does not support radix cache.\n", + "\n", + "To run the following tests or benchmarks, you also need to install [**cutex**](https://pypi.org/project/cutex/): \n", + "\n", + "`pip install cutex`\n", + "\n", + "### Performance Highlights\n", + "\n", + "- Official EAGLE code ([SafeAILab/EAGLE](https://github.com/SafeAILab/EAGLE)): ~200 tokens/s\n", + "- Standard SGLang Decoding: ~156 tokens/s\n", + "- EAGLE Decoding in SGLang: ~297 tokens/s\n", + "- EAGLE Decoding in SGLang (w/ `torch.compile`): ~316 tokens/s\n", + "\n", + "All benchmarks below were run on a single H100." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## EAGLE Decoding\n", + "\n", + "To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft`) and the relevant EAGLE parameters:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# EAGLE decoding\n", + "from sglang.utils import (\n", + " execute_shell_command,\n", + " wait_for_server,\n", + " terminate_process,\n", + " print_highlight,\n", + ")\n", + "\n", + "server_process = execute_shell_command(\n", + " \"\"\"\n", + "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n", + " --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", + " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 --port=30020\n", + "\"\"\"\n", + ")\n", + "\n", + "wait_for_server(\"http://localhost:30020\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "\n", + "client = openai.Client(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", + "\n", + "response = client.chat.completions.create(\n", + " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n", + " ],\n", + " temperature=0,\n", + " max_tokens=64,\n", + ")\n", + "\n", + "print_highlight(f\"Response: {response}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### EAGLE Decoding with `torch.compile`\n", + "\n", + "You can also enable `torch.compile` for further optimizations and optionally set `--cuda-graph-max-bs`:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "server_process = execute_shell_command(\n", + " \"\"\"\n", + "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n", + " --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", + " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.7 \\\n", + " --enable-torch-compile --cuda-graph-max-bs 2 --port=30020\n", + "\"\"\"\n", + ")\n", + "\n", + "wait_for_server(\"http://localhost:30020\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Benchmark Script\n", + "\n", + "The following code example shows how to measure the decoding speed when generating tokens:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "import requests\n", + "\n", + "tic = time.time()\n", + "response = requests.post(\n", + " \"http://localhost:30020/generate\",\n", + " json={\n", + " \"text\": \"[INST] Give me a simple FastAPI server. Show the python code. 
[/INST]\",\n", + " \"sampling_params\": {\n", + " \"temperature\": 0,\n", + " \"max_new_tokens\": 256,\n", + " },\n", + " },\n", + ")\n", + "latency = time.time() - tic\n", + "ret = response.json()\n", + "completion_text = ret[\"text\"]\n", + "speed = ret[\"meta_info\"][\"completion_tokens\"] / latency\n", + "\n", + "print_highlight(completion_text)\n", + "print_highlight(f\"speed: {speed:.2f} token/s\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/backend/structured_outputs.ipynb b/docs/backend/structured_outputs.ipynb new file mode 100644 index 00000000000..e413743ccfd --- /dev/null +++ b/docs/backend/structured_outputs.ipynb @@ -0,0 +1,598 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Structured Outputs" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can specify a JSON schema, [regular expression](https://en.wikipedia.org/wiki/Regular_expression) or [EBNF](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form) to constrain the model output. The model output will be guaranteed to follow the given constraints. Only one constraint parameter (`json_schema`, `regex`, or `ebnf`) can be specified for a request.\n", + "\n", + "SGLang supports two grammar backends:\n", + "\n", + "- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n", + "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n", + "\n", + "We suggest using XGrammar for its better performance and utility. XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md). For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n", + "\n", + "To use Xgrammar, simply add `--grammar-backend` xgrammar when launching the server. 
If no backend is specified, Outlines will be used as the default.\n", + "\n", + "For better output quality, **It's advisable to explicitly include instructions in the prompt to guide the model to generate the desired format.** For example, you can specify, 'Please generate the output in the following JSON format: ...'.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## OpenAI Compatible API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sglang.utils import (\n", + " execute_shell_command,\n", + " wait_for_server,\n", + " terminate_process,\n", + " print_highlight,\n", + ")\n", + "import openai\n", + "import os\n", + "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "\n", + "\n", + "server_process = execute_shell_command(\n", + " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0 --grammar-backend xgrammar\"\n", + ")\n", + "\n", + "wait_for_server(\"http://localhost:30000\")\n", + "client = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### JSON\n", + "\n", + "you can directly define a JSON schema or use [Pydantic](https://docs.pydantic.dev/latest/) to define and validate the response." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Using Pydantic**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from pydantic import BaseModel, Field\n", + "\n", + "\n", + "# Define the schema using Pydantic\n", + "class CapitalInfo(BaseModel):\n", + " name: str = Field(..., pattern=r\"^\\w+$\", description=\"Name of the capital city\")\n", + " population: int = Field(..., description=\"Population of the capital city\")\n", + "\n", + "\n", + "response = client.chat.completions.create(\n", + " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"Please generate the information of the capital of France in the JSON format.\",\n", + " },\n", + " ],\n", + " temperature=0,\n", + " max_tokens=128,\n", + " response_format={\n", + " \"type\": \"json_schema\",\n", + " \"json_schema\": {\n", + " \"name\": \"foo\",\n", + " # convert the pydantic model to json schema\n", + " \"schema\": CapitalInfo.model_json_schema(),\n", + " },\n", + " },\n", + ")\n", + "\n", + "response_content = response.choices[0].message.content\n", + "# validate the JSON response by the pydantic model\n", + "capital_info = CapitalInfo.model_validate_json(response_content)\n", + "print_highlight(f\"Validated response: {capital_info.model_dump_json()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**JSON Schema Directly**\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "json_schema = json.dumps(\n", + " {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"name\": {\"type\": \"string\", \"pattern\": \"^[\\\\w]+$\"},\n", + " \"population\": {\"type\": \"integer\"},\n", + " },\n", + " \"required\": [\"name\", \"population\"],\n", + " }\n", + ")\n", + "\n", + "response = client.chat.completions.create(\n", + " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"Give me the information of 
the capital of France in the JSON format.\",\n", + " },\n", + " ],\n", + " temperature=0,\n", + " max_tokens=128,\n", + " response_format={\n", + " \"type\": \"json_schema\",\n", + " \"json_schema\": {\"name\": \"foo\", \"schema\": json.loads(json_schema)},\n", + " },\n", + ")\n", + "\n", + "print_highlight(response.choices[0].message.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### EBNF" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ebnf_grammar = \"\"\"\n", + "root ::= city | description\n", + "city ::= \"London\" | \"Paris\" | \"Berlin\" | \"Rome\"\n", + "description ::= city \" is \" status\n", + "status ::= \"the capital of \" country\n", + "country ::= \"England\" | \"France\" | \"Germany\" | \"Italy\"\n", + "\"\"\"\n", + "\n", + "response = client.chat.completions.create(\n", + " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a helpful geography bot.\"},\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"Give me the information of the capital of France.\",\n", + " },\n", + " ],\n", + " temperature=0,\n", + " max_tokens=32,\n", + " extra_body={\"ebnf\": ebnf_grammar},\n", + ")\n", + "\n", + "print_highlight(response.choices[0].message.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Regular expression" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = client.chat.completions.create(\n", + " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n", + " ],\n", + " temperature=0,\n", + " max_tokens=128,\n", + " extra_body={\"regex\": \"(Paris|London)\"},\n", + ")\n", + "\n", + "print_highlight(response.choices[0].message.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Native API and SGLang Runtime (SRT)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### JSON" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Using Pydantic**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "import json\n", + "from pydantic import BaseModel, Field\n", + "\n", + "\n", + "# Define the schema using Pydantic\n", + "class CapitalInfo(BaseModel):\n", + " name: str = Field(..., pattern=r\"^\\w+$\", description=\"Name of the capital city\")\n", + " population: int = Field(..., description=\"Population of the capital city\")\n", + "\n", + "\n", + "# Make API request\n", + "response = requests.post(\n", + " \"http://localhost:30000/generate\",\n", + " json={\n", + " \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n", + " \"sampling_params\": {\n", + " \"temperature\": 0,\n", + " \"max_new_tokens\": 64,\n", + " \"json_schema\": json.dumps(CapitalInfo.model_json_schema()),\n", + " },\n", + " },\n", + ")\n", + "print_highlight(response.json())\n", + "\n", + "\n", + "response_data = json.loads(response.json()[\"text\"])\n", + "# validate the response by the pydantic model\n", + "capital_info = CapitalInfo.model_validate(response_data)\n", + "print_highlight(f\"Validated response: {capital_info.model_dump_json()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**JSON Schema 
Directly**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "json_schema = json.dumps(\n", + " {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"name\": {\"type\": \"string\", \"pattern\": \"^[\\\\w]+$\"},\n", + " \"population\": {\"type\": \"integer\"},\n", + " },\n", + " \"required\": [\"name\", \"population\"],\n", + " }\n", + ")\n", + "\n", + "# JSON\n", + "response = requests.post(\n", + " \"http://localhost:30000/generate\",\n", + " json={\n", + " \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n", + " \"sampling_params\": {\n", + " \"temperature\": 0,\n", + " \"max_new_tokens\": 64,\n", + " \"json_schema\": json_schema,\n", + " },\n", + " },\n", + ")\n", + "\n", + "print_highlight(response.json())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### EBNF" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "\n", + "response = requests.post(\n", + " \"http://localhost:30000/generate\",\n", + " json={\n", + " \"text\": \"Give me the information of the capital of France.\",\n", + " \"sampling_params\": {\n", + " \"max_new_tokens\": 128,\n", + " \"temperature\": 0,\n", + " \"n\": 3,\n", + " \"ebnf\": (\n", + " \"root ::= city | description\\n\"\n", + " 'city ::= \"London\" | \"Paris\" | \"Berlin\" | \"Rome\"\\n'\n", + " 'description ::= city \" is \" status\\n'\n", + " 'status ::= \"the capital of \" country\\n'\n", + " 'country ::= \"England\" | \"France\" | \"Germany\" | \"Italy\"'\n", + " ),\n", + " },\n", + " \"stream\": False,\n", + " \"return_logprob\": False,\n", + " },\n", + ")\n", + "\n", + "print_highlight(response.json())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Regular expression" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = requests.post(\n", + " \"http://localhost:30000/generate\",\n", + " json={\n", + " \"text\": \"Paris is the capital of\",\n", + " \"sampling_params\": {\n", + " \"temperature\": 0,\n", + " \"max_new_tokens\": 64,\n", + " \"regex\": \"(France|England)\",\n", + " },\n", + " },\n", + ")\n", + "print_highlight(response.json())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Offline Engine API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sglang as sgl\n", + "\n", + "llm = sgl.Engine(\n", + " model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\", grammar_backend=\"xgrammar\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### JSON" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Using Pydantic**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from pydantic import BaseModel, Field\n", + "\n", + "\n", + "prompts = [\n", + " \"Give me the information of the capital of China in the JSON format.\",\n", + " \"Give me the information of the capital of France in the JSON format.\",\n", + " \"Give me the information of the capital of Ireland in the JSON format.\",\n", + "]\n", + "\n", + "\n", + "# Define the schema 
using Pydantic\n", + "class CapitalInfo(BaseModel):\n", + " name: str = Field(..., pattern=r\"^\\w+$\", description=\"Name of the capital city\")\n", + " population: int = Field(..., description=\"Population of the capital city\")\n", + "\n", + "\n", + "sampling_params = {\n", + " \"temperature\": 0.1,\n", + " \"top_p\": 0.95,\n", + " \"json_schema\": json.dumps(CapitalInfo.model_json_schema()),\n", + "}\n", + "\n", + "outputs = llm.generate(prompts, sampling_params)\n", + "for prompt, output in zip(prompts, outputs):\n", + " print_highlight(\"===============================\")\n", + " print_highlight(f\"Prompt: {prompt}\") # validate the output by the pydantic model\n", + " capital_info = CapitalInfo.model_validate_json(output[\"text\"])\n", + " print_highlight(f\"Validated output: {capital_info.model_dump_json()}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**JSON Schema Directly**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompts = [\n", + " \"Give me the information of the capital of China in the JSON format.\",\n", + " \"Give me the information of the capital of France in the JSON format.\",\n", + " \"Give me the information of the capital of Ireland in the JSON format.\",\n", + "]\n", + "\n", + "json_schema = json.dumps(\n", + " {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"name\": {\"type\": \"string\", \"pattern\": \"^[\\\\w]+$\"},\n", + " \"population\": {\"type\": \"integer\"},\n", + " },\n", + " \"required\": [\"name\", \"population\"],\n", + " }\n", + ")\n", + "\n", + "sampling_params = {\"temperature\": 0.1, \"top_p\": 0.95, \"json_schema\": json_schema}\n", + "\n", + "outputs = llm.generate(prompts, sampling_params)\n", + "for prompt, output in zip(prompts, outputs):\n", + " print_highlight(\"===============================\")\n", + " print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### EBNF\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompts = [\n", + " \"Give me the information of the capital of France.\",\n", + " \"Give me the information of the capital of Germany.\",\n", + " \"Give me the information of the capital of Italy.\",\n", + "]\n", + "\n", + "sampling_params = {\n", + " \"temperature\": 0.8,\n", + " \"top_p\": 0.95,\n", + " \"ebnf\": (\n", + " \"root ::= city | description\\n\"\n", + " 'city ::= \"London\" | \"Paris\" | \"Berlin\" | \"Rome\"\\n'\n", + " 'description ::= city \" is \" status\\n'\n", + " 'status ::= \"the capital of \" country\\n'\n", + " 'country ::= \"England\" | \"France\" | \"Germany\" | \"Italy\"'\n", + " ),\n", + "}\n", + "\n", + "outputs = llm.generate(prompts, sampling_params)\n", + "for prompt, output in zip(prompts, outputs):\n", + " print_highlight(\"===============================\")\n", + " print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Regular expression" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prompts = [\n", + " \"Please provide information about London as a major global city:\",\n", + " \"Please provide information about Paris as a major global city:\",\n", + "]\n", + "\n", + "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95, \"regex\": 
\"(France|England)\"}\n", + "\n", + "outputs = llm.generate(prompts, sampling_params)\n", + "for prompt, output in zip(prompts, outputs):\n", + " print_highlight(\"===============================\")\n", + " print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "llm.shutdown()" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/developer/development_guide_using_docker.md b/docs/developer/development_guide_using_docker.md new file mode 100644 index 00000000000..918057d0e96 --- /dev/null +++ b/docs/developer/development_guide_using_docker.md @@ -0,0 +1,47 @@ +# Development Guide Using Docker + +## Setup VSCode + +Download `code` from `Https://code.visualstudio.com/docs/?dv=linux64cli` + +```bash +wget https://vscode.download.prss.microsoft.com/dbazure/download/stable/fabdb6a30b49f79a7aba0f2ad9df9b399473380f/vscode_cli_alpine_x64_cli.tar.gz +tar xf vscode_cli_alpine_x64_cli.tar.gz + +# https://code.visualstudio.com/docs/remote/tunnels +./code tunnel +``` + +## Setup Docker Container + +The following startup command is an example for internal development by the SGLang team. You can **modify or add directory mappings as needed**, especially for model weight downloads, to prevent repeated downloads by different Docker containers. + +### H100 + +```bash +# Change the name to yours +docker run -itd --shm-size 32g --gpus all -v /opt/dlami/nvme/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh +docker exec -it sglang_zhyncs /bin/zsh +``` + +### H200 + +```bash +docker run -itd --shm-size 32g --gpus all -v /mnt/co-research/shared-models:/root/.cache/huggingface --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh +docker exec -it sglang_zhyncs /bin/zsh +``` + +## Profile + +```bash +# Change batch size, input, output and add `disable-cuda-graph` (for easier analysis) +# e.g. DeepSeek V3 +nsys profile -o deepseek_v3 python3 -m sglang.bench_one_batch --batch-size 1 --input 128 --output 256 --model deepseek-ai/DeepSeek-V3 --trust-remote-code --tp 8 --disable-cuda-graph +``` + +## Evaluation + +```bash +# e.g. 
gsm8k 8 shot +python3 benchmark/gsm8k/bench_sglang.py --num-questions 2000 --parallel 2000 --num-shots 8 +``` diff --git a/docs/developer/setup_github_runner.md b/docs/developer/setup_github_runner.md index c82094f6de3..96c9cae0154 100644 --- a/docs/developer/setup_github_runner.md +++ b/docs/developer/setup_github_runner.md @@ -11,9 +11,9 @@ docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04 # Nvidia docker run --shm-size 128g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash # AMD -docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.0.post1-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post1-rocm620 /bin/bash # AMD just the last 2 GPUs -docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.0.post1-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post1-rocm620 /bin/bash ``` ### Step 2: Configure the runner by `config.sh` diff --git a/docs/index.rst b/docs/index.rst index 8c6c018c4ce..aaa46384490 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,7 +28,10 @@ The core features include: backend/openai_api_embeddings.ipynb backend/native_api.ipynb backend/offline_engine_api.ipynb - backend/backend.md + backend/structured_outputs.ipynb + backend/speculative_decoding.ipynb + backend/function_calling.ipynb + backend/server_arguments.md .. toctree:: @@ -55,7 +58,10 @@ The core features include: references/hyperparameter_tuning.md references/benchmark_and_profiling.md references/custom_chat_template.md - references/contributor_guide.md + references/deepseek.md + references/llama_405B.md + references/modelscope.md + references/contribution_guide.md references/troubleshooting.md references/faq.md references/learn_more.md diff --git a/docs/references/benchmark_and_profiling.md b/docs/references/benchmark_and_profiling.md index 329dad33609..0600b192b4f 100644 --- a/docs/references/benchmark_and_profiling.md +++ b/docs/references/benchmark_and_profiling.md @@ -56,22 +56,39 @@ with nvtx.annotate("description", color="color"): ## Other tips 1. You can benchmark a model using dummy weights by only providing the config.json file. This allows for quick testing of model variants without training. To do so, add `--load-format dummy` to the above commands and then you only need a correct `config.json` under the checkpoint folder. +2. You can benchmark a model with modified configs (e.g., less layers) by using `--json-model-override-args`. 
For example, you can benchmark a model with only 2 layers and 2 kv heads using `python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --batch 32 --input-len 256 --output-len 32 --load-format dummy --json-model-override-args '{"num_hidden_layers": 1, "num_key_value_heads": 1}'` + ## Profile with PyTorch Profiler - To profile a server ```bash # set trace path export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log + # start server python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct -python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --profile +# send profiling request from client +python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --sharegpt-output-len 100 --profile ``` - -Traces can be visualized using https://ui.perfetto.dev/. +Please make sure that the `SGLANG_TORCH_PROFILER_DIR` should be set at both server and client side, otherwise the trace file cannot be generated correctly . A secure way will be setting `SGLANG_TORCH_PROFILER_DIR` in the `.*rc` file of shell (e.g. `~/.bashrc` for bash shells). - To profile offline ```bash export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8 ``` + +- View Traces + +Trace files can be loaded and visualized from: +1. https://ui.perfetto.dev/ (any browser) +2. chrome://tracing (Chrome browser only) + +If browser cannot open trace file due to its large size, +client can generate a small trace file (<100MB) by controlling number of prompts and lengths of prompt outputs. +For example, when profiling a server, +```bash +python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 2 --sharegpt-output-len 100 --profile +``` +sets the number of prompts to 2 with `--num-prompts` argument and limits the length of output sequences to 100 with `--sharegpt-output-len` argument, which can generate a small trace file for browser to open smoothly. diff --git a/docs/references/contribution_guide.md b/docs/references/contribution_guide.md new file mode 100644 index 00000000000..b3b7f826894 --- /dev/null +++ b/docs/references/contribution_guide.md @@ -0,0 +1,45 @@ +# Contribution Guide + +Welcome to **SGLang**! We appreciate your interest in contributing. This guide provides a concise overview of how to set up your environment, run tests, build documentation, and open a Pull Request (PR). Whether you’re fixing a small bug or developing a major feature, we encourage following these steps for a smooth contribution process. + +## Setting Up & Building from Source + +### Fork and Clone the Repository + +**Note**: New contributors do **not** have the write permission to push to the official SGLang repo. Please fork the repository under your GitHub account, then clone your fork locally. + +```bash +git clone https://github.com//sglang.git +``` + +### Install Dependencies & Build + +Refer to [Install SGLang from Source](https://docs.sglang.ai/start/install.html#method-2-from-source) documentation for more details on setting up the necessary dependencies. + +## Code Formatting with Pre-Commit + +We use [pre-commit](https://pre-commit.com/) to maintain consistent code style checks. 
Before pushing your changes, please run: + +```bash +pip3 install pre-commit +pre-commit run --all-files +``` + +- **`pre-commit run --all-files`** manually runs all configured checks, applying fixes if possible. If it fails the first time, re-run it to ensure lint errors are fully resolved. Make sure your code passes all checks **before** creating a Pull Request. +- **Do not commit** directly to the `main` branch. Always create a new branch (e.g., `feature/my-new-feature`), push your changes, and open a PR from that branch. + +## Running Unit Tests & Adding to CI + +SGLang uses Python's built-in [unittest](https://docs.python.org/3/library/unittest.html) framework. For detailed instructions on running tests and adding them to CI, please refer to [test/README.md](https://github.com/sgl-project/sglang/tree/main/test/README.md). + +## Writing Documentation & Running Docs CI + +We recommend new contributors start from writing documentation, which helps you quickly understand SGLang codebase. For more details, please refer to [docs/README.md](https://github.com/sgl-project/sglang/tree/main/docs/README.md). + +## Tips for Newcomers + +If you want to contribute but don’t have a specific idea in mind, pick issues labeled [“good first issue” or “help wanted”](https://github.com/sgl-project/sglang/issues?q=is%3Aissue+label%3A%22good+first+issue%22%2C%22help+wanted%22). These tasks typically have lower complexity and provide an excellent introduction to the codebase. Also check out this [code walk-through](https://github.com/zhaochenyang20/Awesome-ML-SYS-Tutorial/tree/main/sglang/code-walk-through) for a deeper look into SGLang’s workflow. + +If you have any questions or want to start a discussion, please feel free to ask in our [Slack channel](https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2um0ad92q-LkU19KQTxCGzlCgRiOiQEw). + +Thank you for your interest in SGLang—**happy coding**! diff --git a/docs/references/contributor_guide.md b/docs/references/contributor_guide.md deleted file mode 100644 index a9b25163d12..00000000000 --- a/docs/references/contributor_guide.md +++ /dev/null @@ -1,14 +0,0 @@ -# Contributor Guide - -## Format Your Code -Use these commands to format your code and pass CI linting tests. - -``` -pip3 install pre-commit -cd sglang -pre-commit install -pre-commit run --all-files -``` - -## Add Unit Tests -Add unit tests under [sglang/test](https://github.com/sgl-project/sglang/tree/main/test). You can learn how to add and run tests from the README.md in that folder. diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md new file mode 100644 index 00000000000..2bdceb90478 --- /dev/null +++ b/docs/references/deepseek.md @@ -0,0 +1,56 @@ +# DeepSeek Model Optimizations + +SGLang provides several optimizations specifically designed for the DeepSeek model to boost its inference speed. This document outlines current optimizations for DeepSeek. Additionally, the SGLang team is actively developing enhancements for [DeepSeek-V3](https://github.com/sgl-project/sglang/issues/2591). + + +## Multi-head Latent Attention (MLA) Throughput Optimizations + +**Description**: [MLA](https://arxiv.org/pdf/2405.04434) is an innovative attention mechanism introduced by the DeepSeek team, aimed at improving inference efficiency. 
SGLang has implemented specific optimizations for this, including:
+
+- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
+
+- **Triton Decoding Kernel Optimization**: In the MLA decoding kernel, there is only one KV head. This optimization reduces memory access to the KV cache by processing multiple query heads within one block, accelerating the decoding process.
+
+- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enable efficient FP8 inference. Additionally, we have implemented a Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
+
+- **CUDA Graph & Torch.compile**: Both MLA and Mixture of Experts (MoE) are compatible with CUDA Graph and Torch.compile, which reduces latency and accelerates decoding speed for small batch sizes.
+
+Overall, with these optimizations, we have achieved up to a 7x acceleration in output throughput compared to the previous version.
+
+
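+To try these optimizations end to end, the sketch below runs a short offline generation with a DeepSeek MLA model. This is an illustrative example rather than part of the original guide: the model id is only a placeholder, and the `enable_torch_compile` keyword is assumed to mirror the `--enable-torch-compile` server flag (MLA itself needs no flag since it is enabled by default).
+
+```python
+import sglang as sgl
+
+if __name__ == "__main__":
+    # MLA kernels are selected automatically for DeepSeek models; torch.compile is opt-in.
+    llm = sgl.Engine(
+        model_path="deepseek-ai/DeepSeek-V2-Lite-Chat",  # placeholder model id
+        trust_remote_code=True,
+        enable_torch_compile=True,  # assumed kwarg, mirroring --enable-torch-compile
+    )
+    outputs = llm.generate(
+        ["Explain multi-head latent attention in one sentence."],
+        {"temperature": 0, "max_new_tokens": 64},
+    )
+    print(outputs[0]["text"])
+    llm.shutdown()
+```
+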

+*Figure: Multi-head Latent Attention for DeepSeek Series Models*
+

+
+**Usage**: MLA optimization is enabled by default. To disable it, use `--disable-mla`.
+
+**Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details.
+
+## Data Parallelism Attention
+
+**Description**: This optimization applies data parallelism (DP) to the MLA attention mechanism of DeepSeek Series Models, which allows for a significant reduction in the KV cache size, enabling larger batch sizes. Each DP worker independently handles different types of batches (prefill, decode, idle), which are then synchronized before and after processing through the Mixture-of-Experts (MoE) layer.
+
+

+*Figure: Data Parallelism Attention for DeepSeek Series Models*
+

+
+**Usage**: This optimization is aimed at improving throughput and should be used in scenarios with high QPS (Queries Per Second). Data Parallelism Attention can be enabled with `--enable-dp-attention` for DeepSeek Series Models.
+
+
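+As a quick sanity check of the high-QPS scenario, the sketch below fires many concurrent requests at a server that is assumed to have been launched with `--enable-dp-attention` on `localhost:30000`. It is illustrative only and simply uses the standard `/generate` endpoint; the prompt set and concurrency level are arbitrary.
+
+```python
+from concurrent.futures import ThreadPoolExecutor
+
+import requests
+
+
+def generate(prompt):
+    # Standard /generate request; "text" holds the generated completion.
+    payload = {"text": prompt, "sampling_params": {"temperature": 0, "max_new_tokens": 32}}
+    return requests.post("http://localhost:30000/generate", json=payload).json()["text"]
+
+
+prompts = [f"Question {i}: name one large city in France." for i in range(64)]
+# DP attention targets exactly this kind of many-concurrent-requests workload.
+with ThreadPoolExecutor(max_workers=16) as pool:
+    results = list(pool.map(generate, prompts))
+print(f"received {len(results)} responses")
+```
+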

+*Figure: Data Parallelism Attention Performance Comparison*
+

+ +**Reference**: Check [Blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models). + +## Multi Node Tensor Parallelism + +**Description**: For users with limited memory on a single node, SGLang supports serving DeepSeek Series Models, including DeepSeek V3, across multiple nodes using tensor parallelism. This approach partitions the model parameters across multiple GPUs or nodes to handle models that are too large for one node's memory. + +**Usage**: Check [here](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-2-h208) for usage examples. + +## Block-wise FP8 + +**Description**: SGLang implements block-wise FP8 quantization with two key optimizations: + +- **Activation**: E4M3 format using per-token-per-128-channel sub-vector scales with online casting. +- **Weight**: Per-128x128-block quantization for better numerical stability. + +**Usage**: turn on by default for DeepSeek V3 models. diff --git a/docs/references/llama_405B.md b/docs/references/llama_405B.md new file mode 100644 index 00000000000..a63b012fb27 --- /dev/null +++ b/docs/references/llama_405B.md @@ -0,0 +1,19 @@ +# Run Llama 3.1 405B + +## Run 405B (fp8) on a Single Node + +```bash +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8 +``` + +## Run 405B (fp16) on Two Nodes + +```bash +# on the first node, replace 172.16.4.52:20000 with your own node ip address and port + +python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --dist-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0 + +# on the second node, replace 172.18.45.52:20000 with your own node ip address and port + +python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --dist-init-addr 172.18.45.52:20000 --nnodes 2 --node-rank 1 +``` diff --git a/docs/references/modelscope.md b/docs/references/modelscope.md new file mode 100644 index 00000000000..4740c2770f9 --- /dev/null +++ b/docs/references/modelscope.md @@ -0,0 +1,28 @@ +# Use Models From ModelScope + +To use a model from [ModelScope](https://www.modelscope.cn), set the environment variable `SGLANG_USE_MODELSCOPE`. + +```bash +export SGLANG_USE_MODELSCOPE=true +``` + +We take [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instruct) as an example. + +Launch the Server: +```bash +python -m sglang.launch_server --model-path qwen/Qwen2-7B-Instruct --port 30000 +``` + +Or start it by docker: + +```bash +docker run --gpus all \ + -p 30000:30000 \ + -v ~/.cache/modelscope:/root/.cache/modelscope \ + --env "SGLANG_USE_MODELSCOPE=true" \ + --ipc=host \ + lmsysorg/sglang:latest \ + python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 30000 +``` + +Note that modelscope uses a different cache directory than huggingface. You may need to set it manually to avoid running out of disk space. diff --git a/docs/references/production_metrics.md b/docs/references/production_metrics.md index 36515f3d454..20a34e54bcb 100644 --- a/docs/references/production_metrics.md +++ b/docs/references/production_metrics.md @@ -1,189 +1,129 @@ # Production Metrics -sglang exposes the following metrics via Prometheus. The metrics are namespaced by `$name` (the model name). +SGLang exposes the following metrics via Prometheus. The metrics are namespaced by `$name` (the model name). 
An example of the monitoring dashboard is available in [examples/monitoring/grafana.json](../examples/monitoring/grafana.json). Here is an example of the metrics: ``` -# HELP sglang:max_total_num_tokens Maximum total number of tokens -# TYPE sglang:max_total_num_tokens gauge -sglang:max_total_num_tokens{name="google/gemma-2-9b-it"} 161721.0 -# HELP sglang:max_prefill_tokens Maximum prefill tokens -# TYPE sglang:max_prefill_tokens gauge -sglang:max_prefill_tokens{name="google/gemma-2-9b-it"} 16384.0 -# HELP sglang:max_running_requests Maximum running requests -# TYPE sglang:max_running_requests gauge -sglang:max_running_requests{name="google/gemma-2-9b-it"} 4097.0 -# HELP sglang:context_len Context length -# TYPE sglang:context_len gauge -sglang:context_len{name="google/gemma-2-9b-it"} 8192.0 +$ curl http://localhost:30000/metrics + # HELP sglang:prompt_tokens_total Number of prefill tokens processed. # TYPE sglang:prompt_tokens_total counter -sglang:prompt_tokens_total{name="google/gemma-2-9b-it"} 506780.0 +sglang:prompt_tokens_total{model_name="meta-llama/Llama-3.1-8B-Instruct"} 7.0 # HELP sglang:generation_tokens_total Number of generation tokens processed. # TYPE sglang:generation_tokens_total counter -sglang:generation_tokens_total{name="google/gemma-2-9b-it"} 424549.0 -# HELP sglang:num_requests_running Number of requests currently running on GPU -# TYPE sglang:num_requests_running gauge -sglang:num_requests_running{name="google/gemma-2-9b-it"} 0.0 -# HELP sglang:num_requests_waiting Number of requests waiting to be processed. -# TYPE sglang:num_requests_waiting gauge -sglang:num_requests_waiting{name="google/gemma-2-9b-it"} 0.0 -# HELP sglang:gen_throughput Gen token throughput (token/s) -# TYPE sglang:gen_throughput gauge -sglang:gen_throughput{name="google/gemma-2-9b-it"} 0.0 -# HELP sglang:token_usage Total token usage -# TYPE sglang:token_usage gauge -sglang:token_usage{name="google/gemma-2-9b-it"} 0.01 -# HELP sglang:new_seq Number of new sequences -# TYPE sglang:new_seq gauge -sglang:new_seq{name="google/gemma-2-9b-it"} 0.0 -# HELP sglang:new_token Number of new token -# TYPE sglang:new_token gauge -sglang:new_token{name="google/gemma-2-9b-it"} 0.0 -# HELP sglang:cached_token Number of cached token -# TYPE sglang:cached_token gauge -sglang:cached_token{name="google/gemma-2-9b-it"} 0.0 -# HELP sglang:cache_hit_rate Cache hit rate -# TYPE sglang:cache_hit_rate gauge -sglang:cache_hit_rate{name="google/gemma-2-9b-it"} 10.61 -# HELP sglang:queue_req Number of queued requests -# TYPE sglang:queue_req gauge -sglang:queue_req{name="google/gemma-2-9b-it"} 0.0 +sglang:generation_tokens_total{model_name="meta-llama/Llama-3.1-8B-Instruct"} 8.0 # HELP sglang:time_to_first_token_seconds Histogram of time to first token in seconds. 
# TYPE sglang:time_to_first_token_seconds histogram -sglang:time_to_first_token_seconds_sum{name="google/gemma-2-9b-it"} 656.0780844688416 -sglang:time_to_first_token_seconds_bucket{le="0.001",name="google/gemma-2-9b-it"} 0.0 -sglang:time_to_first_token_seconds_bucket{le="0.005",name="google/gemma-2-9b-it"} 0.0 -sglang:time_to_first_token_seconds_bucket{le="0.01",name="google/gemma-2-9b-it"} 0.0 -sglang:time_to_first_token_seconds_bucket{le="0.02",name="google/gemma-2-9b-it"} 0.0 -sglang:time_to_first_token_seconds_bucket{le="0.04",name="google/gemma-2-9b-it"} 207.0 -sglang:time_to_first_token_seconds_bucket{le="0.06",name="google/gemma-2-9b-it"} 456.0 -sglang:time_to_first_token_seconds_bucket{le="0.08",name="google/gemma-2-9b-it"} 598.0 -sglang:time_to_first_token_seconds_bucket{le="0.1",name="google/gemma-2-9b-it"} 707.0 -sglang:time_to_first_token_seconds_bucket{le="0.25",name="google/gemma-2-9b-it"} 1187.0 -sglang:time_to_first_token_seconds_bucket{le="0.5",name="google/gemma-2-9b-it"} 1350.0 -sglang:time_to_first_token_seconds_bucket{le="0.75",name="google/gemma-2-9b-it"} 2124.0 -sglang:time_to_first_token_seconds_bucket{le="1.0",name="google/gemma-2-9b-it"} 2124.0 -sglang:time_to_first_token_seconds_bucket{le="2.5",name="google/gemma-2-9b-it"} 2124.0 -sglang:time_to_first_token_seconds_bucket{le="5.0",name="google/gemma-2-9b-it"} 2124.0 -sglang:time_to_first_token_seconds_bucket{le="7.5",name="google/gemma-2-9b-it"} 2124.0 -sglang:time_to_first_token_seconds_bucket{le="10.0",name="google/gemma-2-9b-it"} 2124.0 -sglang:time_to_first_token_seconds_bucket{le="15.0",name="google/gemma-2-9b-it"} 2124.0 -sglang:time_to_first_token_seconds_bucket{le="20.0",name="google/gemma-2-9b-it"} 2124.0 -sglang:time_to_first_token_seconds_bucket{le="25.0",name="google/gemma-2-9b-it"} 2124.0 -sglang:time_to_first_token_seconds_bucket{le="30.0",name="google/gemma-2-9b-it"} 2124.0 -sglang:time_to_first_token_seconds_bucket{le="+Inf",name="google/gemma-2-9b-it"} 2124.0 -sglang:time_to_first_token_seconds_count{name="google/gemma-2-9b-it"} 2124.0 -# HELP sglang:time_per_output_token_seconds Histogram of time per output token in seconds. 
-# TYPE sglang:time_per_output_token_seconds histogram -sglang:time_per_output_token_seconds_sum{name="google/gemma-2-9b-it"} 29846.5393948555 -sglang:time_per_output_token_seconds_bucket{le="0.005",name="google/gemma-2-9b-it"} 0.0 -sglang:time_per_output_token_seconds_bucket{le="0.01",name="google/gemma-2-9b-it"} 0.0 -sglang:time_per_output_token_seconds_bucket{le="0.015",name="google/gemma-2-9b-it"} 0.0 -sglang:time_per_output_token_seconds_bucket{le="0.02",name="google/gemma-2-9b-it"} 9602.0 -sglang:time_per_output_token_seconds_bucket{le="0.025",name="google/gemma-2-9b-it"} 30060.0 -sglang:time_per_output_token_seconds_bucket{le="0.03",name="google/gemma-2-9b-it"} 39184.0 -sglang:time_per_output_token_seconds_bucket{le="0.04",name="google/gemma-2-9b-it"} 61387.0 -sglang:time_per_output_token_seconds_bucket{le="0.05",name="google/gemma-2-9b-it"} 78835.0 -sglang:time_per_output_token_seconds_bucket{le="0.075",name="google/gemma-2-9b-it"} 139394.0 -sglang:time_per_output_token_seconds_bucket{le="0.1",name="google/gemma-2-9b-it"} 422029.0 -sglang:time_per_output_token_seconds_bucket{le="0.15",name="google/gemma-2-9b-it"} 422029.0 -sglang:time_per_output_token_seconds_bucket{le="0.2",name="google/gemma-2-9b-it"} 422029.0 -sglang:time_per_output_token_seconds_bucket{le="0.3",name="google/gemma-2-9b-it"} 422424.0 -sglang:time_per_output_token_seconds_bucket{le="0.4",name="google/gemma-2-9b-it"} 422424.0 -sglang:time_per_output_token_seconds_bucket{le="0.5",name="google/gemma-2-9b-it"} 422425.0 -sglang:time_per_output_token_seconds_bucket{le="0.75",name="google/gemma-2-9b-it"} 422425.0 -sglang:time_per_output_token_seconds_bucket{le="1.0",name="google/gemma-2-9b-it"} 422425.0 -sglang:time_per_output_token_seconds_bucket{le="2.5",name="google/gemma-2-9b-it"} 422425.0 -sglang:time_per_output_token_seconds_bucket{le="+Inf",name="google/gemma-2-9b-it"} 422425.0 -sglang:time_per_output_token_seconds_count{name="google/gemma-2-9b-it"} 422425.0 -# HELP sglang:request_prompt_tokens Number of prefill tokens processed -# TYPE sglang:request_prompt_tokens histogram -sglang:request_prompt_tokens_sum{name="google/gemma-2-9b-it"} 500552.0 -sglang:request_prompt_tokens_bucket{le="1.0",name="google/gemma-2-9b-it"} 0.0 -sglang:request_prompt_tokens_bucket{le="2.0",name="google/gemma-2-9b-it"} 0.0 -sglang:request_prompt_tokens_bucket{le="5.0",name="google/gemma-2-9b-it"} 22.0 -sglang:request_prompt_tokens_bucket{le="10.0",name="google/gemma-2-9b-it"} 191.0 -sglang:request_prompt_tokens_bucket{le="20.0",name="google/gemma-2-9b-it"} 511.0 -sglang:request_prompt_tokens_bucket{le="50.0",name="google/gemma-2-9b-it"} 825.0 -sglang:request_prompt_tokens_bucket{le="100.0",name="google/gemma-2-9b-it"} 997.0 -sglang:request_prompt_tokens_bucket{le="200.0",name="google/gemma-2-9b-it"} 1182.0 -sglang:request_prompt_tokens_bucket{le="500.0",name="google/gemma-2-9b-it"} 1748.0 -sglang:request_prompt_tokens_bucket{le="1000.0",name="google/gemma-2-9b-it"} 2102.0 -sglang:request_prompt_tokens_bucket{le="2000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:request_prompt_tokens_bucket{le="5000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:request_prompt_tokens_bucket{le="10000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:request_prompt_tokens_bucket{le="20000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:request_prompt_tokens_bucket{le="50000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:request_prompt_tokens_bucket{le="100000.0",name="google/gemma-2-9b-it"} 2104.0 
-sglang:request_prompt_tokens_bucket{le="+Inf",name="google/gemma-2-9b-it"} 2104.0 -sglang:request_prompt_tokens_count{name="google/gemma-2-9b-it"} 2104.0 -# HELP sglang:request_generation_tokens Number of generation tokens processed. -# TYPE sglang:request_generation_tokens histogram -sglang:request_generation_tokens_sum{name="google/gemma-2-9b-it"} 424529.0 -sglang:request_generation_tokens_bucket{le="1.0",name="google/gemma-2-9b-it"} 0.0 -sglang:request_generation_tokens_bucket{le="2.0",name="google/gemma-2-9b-it"} 0.0 -sglang:request_generation_tokens_bucket{le="5.0",name="google/gemma-2-9b-it"} 49.0 -sglang:request_generation_tokens_bucket{le="10.0",name="google/gemma-2-9b-it"} 202.0 -sglang:request_generation_tokens_bucket{le="20.0",name="google/gemma-2-9b-it"} 448.0 -sglang:request_generation_tokens_bucket{le="50.0",name="google/gemma-2-9b-it"} 814.0 -sglang:request_generation_tokens_bucket{le="100.0",name="google/gemma-2-9b-it"} 979.0 -sglang:request_generation_tokens_bucket{le="200.0",name="google/gemma-2-9b-it"} 1266.0 -sglang:request_generation_tokens_bucket{le="500.0",name="google/gemma-2-9b-it"} 1883.0 -sglang:request_generation_tokens_bucket{le="1000.0",name="google/gemma-2-9b-it"} 2095.0 -sglang:request_generation_tokens_bucket{le="2000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:request_generation_tokens_bucket{le="5000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:request_generation_tokens_bucket{le="10000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:request_generation_tokens_bucket{le="20000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:request_generation_tokens_bucket{le="50000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:request_generation_tokens_bucket{le="100000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:request_generation_tokens_bucket{le="+Inf",name="google/gemma-2-9b-it"} 2104.0 -sglang:request_generation_tokens_count{name="google/gemma-2-9b-it"} 2104.0 +sglang:time_to_first_token_seconds_sum{model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.30457592010498047 +sglang:time_to_first_token_seconds_bucket{le="0.001",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_to_first_token_seconds_bucket{le="0.005",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_to_first_token_seconds_bucket{le="0.01",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_to_first_token_seconds_bucket{le="0.02",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_to_first_token_seconds_bucket{le="0.04",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_to_first_token_seconds_bucket{le="0.06",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_to_first_token_seconds_bucket{le="0.08",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_to_first_token_seconds_bucket{le="0.1",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_to_first_token_seconds_bucket{le="0.25",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_to_first_token_seconds_bucket{le="0.5",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_to_first_token_seconds_bucket{le="0.75",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_to_first_token_seconds_bucket{le="1.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_to_first_token_seconds_bucket{le="2.5",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_to_first_token_seconds_bucket{le="5.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 
+sglang:time_to_first_token_seconds_bucket{le="7.5",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_to_first_token_seconds_bucket{le="10.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_to_first_token_seconds_bucket{le="15.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_to_first_token_seconds_bucket{le="20.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_to_first_token_seconds_bucket{le="25.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_to_first_token_seconds_bucket{le="30.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_to_first_token_seconds_bucket{le="+Inf",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_to_first_token_seconds_count{model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 # HELP sglang:e2e_request_latency_seconds Histogram of End-to-end request latency in seconds # TYPE sglang:e2e_request_latency_seconds histogram -sglang:e2e_request_latency_seconds_sum{name="google/gemma-2-9b-it"} 70517.99934530258 -sglang:e2e_request_latency_seconds_bucket{le="1.0",name="google/gemma-2-9b-it"} 2.0 -sglang:e2e_request_latency_seconds_bucket{le="2.0",name="google/gemma-2-9b-it"} 21.0 -sglang:e2e_request_latency_seconds_bucket{le="5.0",name="google/gemma-2-9b-it"} 54.0 -sglang:e2e_request_latency_seconds_bucket{le="10.0",name="google/gemma-2-9b-it"} 311.0 -sglang:e2e_request_latency_seconds_bucket{le="20.0",name="google/gemma-2-9b-it"} 733.0 -sglang:e2e_request_latency_seconds_bucket{le="50.0",name="google/gemma-2-9b-it"} 1563.0 -sglang:e2e_request_latency_seconds_bucket{le="100.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:e2e_request_latency_seconds_bucket{le="200.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:e2e_request_latency_seconds_bucket{le="500.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:e2e_request_latency_seconds_bucket{le="1000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:e2e_request_latency_seconds_bucket{le="2000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:e2e_request_latency_seconds_bucket{le="5000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:e2e_request_latency_seconds_bucket{le="10000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:e2e_request_latency_seconds_bucket{le="20000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:e2e_request_latency_seconds_bucket{le="50000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:e2e_request_latency_seconds_bucket{le="100000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:e2e_request_latency_seconds_bucket{le="+Inf",name="google/gemma-2-9b-it"} 2104.0 -sglang:e2e_request_latency_seconds_count{name="google/gemma-2-9b-it"} 2104.0 -# HELP sglang:waiting_request_latency_seconds Histogram of request waiting time in seconds -# TYPE sglang:waiting_request_latency_seconds histogram -sglang:waiting_request_latency_seconds_sum{name="google/gemma-2-9b-it"} 24885.007263183594 -sglang:waiting_request_latency_seconds_bucket{le="1.0",name="google/gemma-2-9b-it"} 421.0 -sglang:waiting_request_latency_seconds_bucket{le="2.0",name="google/gemma-2-9b-it"} 563.0 -sglang:waiting_request_latency_seconds_bucket{le="5.0",name="google/gemma-2-9b-it"} 900.0 -sglang:waiting_request_latency_seconds_bucket{le="10.0",name="google/gemma-2-9b-it"} 1270.0 -sglang:waiting_request_latency_seconds_bucket{le="20.0",name="google/gemma-2-9b-it"} 1623.0 -sglang:waiting_request_latency_seconds_bucket{le="50.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:waiting_request_latency_seconds_bucket{le="100.0",name="google/gemma-2-9b-it"} 2104.0 
-sglang:waiting_request_latency_seconds_bucket{le="200.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:waiting_request_latency_seconds_bucket{le="500.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:waiting_request_latency_seconds_bucket{le="1000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:waiting_request_latency_seconds_bucket{le="2000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:waiting_request_latency_seconds_bucket{le="5000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:waiting_request_latency_seconds_bucket{le="10000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:waiting_request_latency_seconds_bucket{le="20000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:waiting_request_latency_seconds_bucket{le="50000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:waiting_request_latency_seconds_bucket{le="100000.0",name="google/gemma-2-9b-it"} 2104.0 -sglang:waiting_request_latency_seconds_bucket{le="+Inf",name="google/gemma-2-9b-it"} 2104.0 -sglang:waiting_request_latency_seconds_count{name="google/gemma-2-9b-it"} 2104.0 +sglang:e2e_request_latency_seconds_sum{model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.30521273612976074 +sglang:e2e_request_latency_seconds_bucket{le="0.3",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:e2e_request_latency_seconds_bucket{le="0.5",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_bucket{le="0.8",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_bucket{le="1.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_bucket{le="1.5",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_bucket{le="2.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_bucket{le="2.5",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_bucket{le="5.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_bucket{le="10.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_bucket{le="15.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_bucket{le="20.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_bucket{le="30.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_bucket{le="40.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_bucket{le="50.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_bucket{le="60.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_bucket{le="+Inf",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:e2e_request_latency_seconds_count{model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +# HELP sglang:time_per_output_token_seconds Histogram of time per output token in seconds. 
+# TYPE sglang:time_per_output_token_seconds histogram +sglang:time_per_output_token_seconds_sum{model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0381757915019989 +sglang:time_per_output_token_seconds_bucket{le="0.005",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_per_output_token_seconds_bucket{le="0.01",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_per_output_token_seconds_bucket{le="0.015",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_per_output_token_seconds_bucket{le="0.02",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_per_output_token_seconds_bucket{le="0.025",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_per_output_token_seconds_bucket{le="0.03",model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +sglang:time_per_output_token_seconds_bucket{le="0.04",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_per_output_token_seconds_bucket{le="0.05",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_per_output_token_seconds_bucket{le="0.075",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_per_output_token_seconds_bucket{le="0.1",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_per_output_token_seconds_bucket{le="0.15",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_per_output_token_seconds_bucket{le="0.2",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_per_output_token_seconds_bucket{le="0.3",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_per_output_token_seconds_bucket{le="0.4",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_per_output_token_seconds_bucket{le="0.5",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_per_output_token_seconds_bucket{le="0.75",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_per_output_token_seconds_bucket{le="1.0",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_per_output_token_seconds_bucket{le="2.5",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_per_output_token_seconds_bucket{le="+Inf",model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +sglang:time_per_output_token_seconds_count{model_name="meta-llama/Llama-3.1-8B-Instruct"} 1.0 +# HELP sglang:func_latency_seconds Function latency in seconds +# TYPE sglang:func_latency_seconds histogram +sglang:func_latency_seconds_sum{name="generate_request"} 0.3061351010110229 +sglang:func_latency_seconds_bucket{le="0.05",name="generate_request"} 0.0 +sglang:func_latency_seconds_bucket{le="0.07500000000000001",name="generate_request"} 0.0 +sglang:func_latency_seconds_bucket{le="0.1125",name="generate_request"} 0.0 +sglang:func_latency_seconds_bucket{le="0.16875",name="generate_request"} 0.0 +sglang:func_latency_seconds_bucket{le="0.253125",name="generate_request"} 0.0 +sglang:func_latency_seconds_bucket{le="0.3796875",name="generate_request"} 1.0 +sglang:func_latency_seconds_bucket{le="0.56953125",name="generate_request"} 1.0 +sglang:func_latency_seconds_bucket{le="0.8542968750000001",name="generate_request"} 1.0 +sglang:func_latency_seconds_bucket{le="1.2814453125",name="generate_request"} 1.0 +sglang:func_latency_seconds_bucket{le="1.9221679687500002",name="generate_request"} 1.0 +sglang:func_latency_seconds_bucket{le="2.8832519531250003",name="generate_request"} 1.0 +sglang:func_latency_seconds_bucket{le="4.3248779296875",name="generate_request"} 1.0 +sglang:func_latency_seconds_bucket{le="6.487316894531251",name="generate_request"} 1.0 
+sglang:func_latency_seconds_bucket{le="9.730975341796876",name="generate_request"} 1.0 +sglang:func_latency_seconds_bucket{le="14.596463012695313",name="generate_request"} 1.0 +sglang:func_latency_seconds_bucket{le="21.89469451904297",name="generate_request"} 1.0 +sglang:func_latency_seconds_bucket{le="32.84204177856446",name="generate_request"} 1.0 +sglang:func_latency_seconds_bucket{le="49.26306266784668",name="generate_request"} 1.0 +sglang:func_latency_seconds_bucket{le="+Inf",name="generate_request"} 1.0 +sglang:func_latency_seconds_count{name="generate_request"} 1.0 +# HELP sglang:num_running_reqs The number of running requests +# TYPE sglang:num_running_reqs gauge +sglang:num_running_reqs{model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +# HELP sglang:num_used_tokens The number of used tokens +# TYPE sglang:num_used_tokens gauge +sglang:num_used_tokens{model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +# HELP sglang:gen_throughput The generate throughput (token/s) +# TYPE sglang:gen_throughput gauge +sglang:gen_throughput{model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +# HELP sglang:num_queue_reqs The number of requests in the waiting queue +# TYPE sglang:num_queue_reqs gauge +sglang:num_queue_reqs{model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +# HELP sglang:token_usage The token usage +# TYPE sglang:token_usage gauge +sglang:token_usage{model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 +# HELP sglang:cache_hit_rate The cache hit rate +# TYPE sglang:cache_hit_rate gauge +sglang:cache_hit_rate{model_name="meta-llama/Llama-3.1-8B-Instruct"} 0.0 ``` ## Setup Guide diff --git a/docs/references/sampling_params.md b/docs/references/sampling_params.md index 147e6c2abe7..77d7c9f82e7 100644 --- a/docs/references/sampling_params.md +++ b/docs/references/sampling_params.md @@ -1,8 +1,7 @@ # Sampling Parameters in SGLang Runtime This doc describes the sampling parameters of the SGLang Runtime. It is the low-level endpoint of the runtime. -If you want a high-level endpoint that can automatically handle chat templates, consider using the [OpenAI Compatible API -](https://github.com/sgl-project/sglang?tab=readme-ov-file#openai-compatible-api). +If you want a high-level endpoint that can automatically handle chat templates, consider using the [OpenAI Compatible API](../backend/openai_api_completions.ipynb). The `/generate` endpoint accepts the following arguments in the JSON format. @@ -33,6 +32,20 @@ class GenerateReqInput: return_text_in_logprobs: bool = False # Whether to stream output. stream: bool = False + # Whether to log metrics for this request (e.g. health_generate calls do not log metrics) + log_metrics: bool = True + + # The modalities of the image data [image, multi-images, video] + modalities: Optional[List[str]] = None + # LoRA related + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + + # Session info for continual prompting + session_params: Optional[Union[List[Dict], Dict]] = None + # Custom logit processor for advanced sampling control. Must be a serialized instance + # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py + # Use the processor's `to_str()` method to generate the serialized string. + custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None ``` The `sampling_params` follows this format @@ -40,10 +53,9 @@ The `sampling_params` follows this format ```python # The maximum number of output tokens max_new_tokens: int = 128, -# Stop when hitting any of the strings in this list. 
+# Stop when hitting any of the strings in this list stop: Optional[Union[str, List[str]]] = None, -# Stop when hitting any of the token_ids in this list. Could be useful when mixed with -# `min_new_tokens`. +# Stop when hitting any of the token_ids in this list stop_token_ids: Optional[List[int]] = [], # Sampling temperature temperature: float = 1.0, @@ -53,21 +65,26 @@ top_p: float = 1.0, top_k: int = -1, # Min-p sampling min_p: float = 0.0, -# Whether to ignore EOS token. +# Whether to ignore EOS token ignore_eos: bool = False, -# Whether to skip the special tokens during detokenization. +# Whether to skip the special tokens during detokenization skip_special_tokens: bool = True, -# Whether to add spaces between special tokens during detokenization. +# Whether to add spaces between special tokens during detokenization spaces_between_special_tokens: bool = True, -# Constrains the output to follow a given regular expression. -regex: Optional[str] = None, # Do parallel sampling and return `n` outputs. n: int = 1, -# Constrains the output to follow a given JSON schema. -# `regex` and `json_schema` cannot be set at the same time. + +## Structured Outputs +# Only one of the below three can be set for a request. + +# Constrain the output to follow a given JSON schema. json_schema: Optional[str] = None, +# Constrain the output to follow a given regular expression. +regex: Optional[str] = None, +# Constrain the output to follow a given EBNF grammar. +ebnf: Optional[str] = None, -## Penalties. See [Performance Implications on Penalties] section below for more informations. +## Penalties. # Float that penalizes new tokens based on their frequency in the generated text so far. # Values > 0 encourage the model to use new tokens, while values < 0 encourage the model to @@ -87,6 +104,14 @@ repetition_penalty: float = 1.0, # difficult to infer the correct token ID by given `stop` strings. # Must be 0 <= value < max_new_tokens. Setting to 0 (default) will disable this penalty. min_new_tokens: int = 0, + + +## Custom Parameters for Custom Logit Processor. +# A dictionary of custom parameters for the custom logit processor. +# The custom logit processor takes a list of dictionaries as input, where each +# dictionary is the custom parameters for one token in a batch of the input. +# See also python/sglang/srt/sampling/custom_logit_processor.py +custom_params: Optional[Dict[str, Any]] = None, ``` ## Examples @@ -180,25 +205,35 @@ print(response.json()) The `image_data` can be a file name, a URL, or a base64 encoded string. See also `python/sglang/srt/utils.py:load_image`. Streaming is supported in a similar manner as [above](#streaming). -### Structured decoding (JSON, Regex) -You can specify a JSON schema or a regular expression to constrain the model output. The model output will be guaranteed to follow the given constraints. +### Structured Outputs (JSON, Regex, EBNF) +You can specify a JSON schema, regular expression or [EBNF](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form) to constrain the model output. The model output will be guaranteed to follow the given constraints. Only one constraint parameter (`json_schema`, `regex`, or `ebnf`) can be specified for a request. + +SGLang supports two grammar backends: + +- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints. +- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints. 
+ - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md) + +Initialize the XGrammar backend using `--grammar-backend xgrammar` flag +```bash +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \ +--port 30000 --host 0.0.0.0 --grammar-backend [xgrammar|outlines] # xgrammar or outlines (default: outlines) +``` ```python import json import requests -json_schema = json.dumps( - { - "type": "object", - "properties": { - "name": {"type": "string", "pattern": "^[\\w]+$"}, - "population": {"type": "integer"}, - }, - "required": ["name", "population"], - } -) +json_schema = json.dumps({ + "type": "object", + "properties": { + "name": {"type": "string", "pattern": "^[\\w]+$"}, + "population": {"type": "integer"}, + }, + "required": ["name", "population"], +}) -# JSON +# JSON (works with both Outlines and XGrammar) response = requests.post( "http://localhost:30000/generate", json={ @@ -212,7 +247,7 @@ response = requests.post( ) print(response.json()) -# Regular expression +# Regular expression (Outlines backend only) response = requests.post( "http://localhost:30000/generate", json={ @@ -225,4 +260,64 @@ response = requests.post( }, ) print(response.json()) + +# EBNF (XGrammar backend only) +response = requests.post( + "http://localhost:30000/generate", + json={ + "text": "Write a greeting.", + "sampling_params": { + "temperature": 0, + "max_new_tokens": 64, + "ebnf": 'root ::= "Hello" | "Hi" | "Hey"', + }, + }, +) +print(response.json()) +``` +### Custom Logit Processor +Launch a server with `--enable-custom-logit-processor` flag on. +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --enable-custom-logit-processor +``` + +Define a custom logit processor that will always sample a specific token id. +```python +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor + +class DeterministicLogitProcessor(CustomLogitProcessor): + """A dummy logit processor that changes the logits to always + sample the given token id. 
+ """ + + def __call__(self, logits, custom_param_list): + # Check that the number of logits matches the number of custom parameters + assert logits.shape[0] == len(custom_param_list) + key = "token_id" + + for i, param_dict in enumerate(custom_param_list): + # Mask all other tokens + logits[i, :] = -float("inf") + # Assign highest probability to the specified token + logits[i, param_dict[key]] = 0.0 + return logits +``` + +Send a request +```python +import requests + +response = requests.post( + "http://localhost:30000/generate", + json={ + "text": "The capital of France is", + "custom_logit_processor": DeterministicLogitProcessor().to_str(), + "sampling_params": { + "temperature": 0.0, + "max_new_tokens": 32, + "custom_params": {"token_id": 5}, + }, + }, +) +print(response.json()) ``` diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index bf1044f8498..85de12f9f47 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -2,10 +2,10 @@ ## Generative Models - Llama / Llama 2 / Llama 3 / Llama 3.1 / Llama 3.2 -- Mistral / Mixtral / Mistral NeMo +- Mistral / Mixtral / Mistral NeMo / Mistral Small 3 - Gemma / Gemma 2 - Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL -- DeepSeek / DeepSeek 2 +- DeepSeek / DeepSeek 2 / [DeepSeek 3](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3) - OLMoE - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/) - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov --port=30000 --chat-template=chatml-llava` @@ -24,11 +24,13 @@ - InternLM 2 - Exaone 3 - BaiChuan2 -- MiniCPM / MiniCPM 3 +- MiniCPM / MiniCPM 3 / MiniCPMV - XVERSE / XVERSE MoE - SmolLM - GLM-4 +- Phi-3 / Phi-4 - Phi-3-Small +- IBM Granite 3 ## Embedding Models @@ -76,10 +78,12 @@ Another valuable resource is the [vLLM Models Directory](https://github.com/vllm To port a model from vLLM to SGLang, you can compare these two files [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py) and [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of Attention with RadixAttention. The other parts are almost identical. Specifically, - Replace vllm's `Attention` with `RadixAttention`. Note that you need to pass `layer_id` all the way to `RadixAttention`. - Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`. + - Replace Multi-headed `Attention` of ViT with SGLang's `VisionAttention`. - Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`). - Remove `Sample`. - Change `forward()` functions, and add `forward_batch`. - Add `EntryClass` at the end. + - Please ensure the new implementation uses **only SGLang components and does not rely on any vLLM components**. 
### Registering an external model implementation @@ -89,7 +93,7 @@ Here is how you can do it: ```python from sglang.srt.models.registry import ModelRegistry -from sglang.srt.server import launch_server +from sglang.srt.entrypoints.http_server import launch_server # for a single model, you can add it to the registry ModelRegistry.models[model_name] = model_class diff --git a/docs/references/torch_compile_cache.md b/docs/references/torch_compile_cache.md new file mode 100644 index 00000000000..f2bb257f430 --- /dev/null +++ b/docs/references/torch_compile_cache.md @@ -0,0 +1,13 @@ +# Enabling cache for torch.compile + +SGLang uses `max-autotune-no-cudagraphs` mode of torch.compile. The auto-tuning can be slow. +If you want to deploy a model on many different machines, you can ship the torch.compile cache to these machines and skip the compilation steps. + +This is based on https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html + + +1. Generate the cache by setting TORCHINDUCTOR_CACHE_DIR and running the model once. +``` +TORCHINDUCTOR_CACHE_DIR=/root/inductor_root_cache python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --enable-torch-compile +``` +2. Copy the cache folder to other machines and launch the server with `TORCHINDUCTOR_CACHE_DIR`. diff --git a/docs/router/router.md b/docs/router/router.md index 11ea8c59065..b5104926239 100644 --- a/docs/router/router.md +++ b/docs/router/router.md @@ -7,14 +7,14 @@ The router is a independent Python package, and it can be used as a drop-in repl ## Installation ```bash -pip install sglang-router +$ pip install sglang-router ``` -Detailed usage of the router can be found in [launch_router](https://github.com/sgl-project/sglang/blob/main/rust/py_src/sglang_router/launch_router.py) and [launch_server](https://github.com/sgl-project/sglang/blob/main/rust/py_src/sglang/launch_server.py). Also, you can directly run the following command to see the usage of the router. +Detailed usage of the router can be found in [launch_router](https://github.com/sgl-project/sglang/blob/main/sgl-router/py_src/sglang_router/launch_router.py) and [launch_server](https://github.com/sgl-project/sglang/blob/main/sgl-router/py_src/sglang/launch_server.py). Also, you can directly run the following command to see the usage of the router. ```bash -python -m sglang_router.launch_server --help -python -m sglang_router.launch_routher --help +$ python -m sglang_router.launch_server --help +$ python -m sglang_router.launch_router --help ``` The router supports two working modes: @@ -27,7 +27,7 @@ The router supports two working modes: This will be a drop-in replacement for the existing `--dp-size` arguement of SGLang Runtime. Under the hood, it uses multi-processes to launch multiple workers, wait for them to be ready, then connect the router to all workers. ```bash -python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dp-size 1 +$ python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dp-size 1 ``` After the server is ready, you can directly send requests to the router as the same way as sending requests to each single worker. @@ -47,12 +47,62 @@ print(response.json()) This is useful for multi-node DP. First, launch workers on multiple nodes, then launch a router on the main node, and connect the router to all workers. 
```bash -python -m sglang_router.launch_router --worker-urls http://worker_url_1 http://worker_url_2 +$ python -m sglang_router.launch_router --worker-urls http://worker_url_1 http://worker_url_2 ``` -## Strategies +## Dynamic Scaling APIs -### Cache-Aware Load-Balancing Router +We offer `/add_worker` and `/remove_worker` APIs to dynamically add or remove workers from the router. + +- `/add_worker` + +Usage: + +```bash +$ curl -X POST http://localhost:30000/add_worker?url=http://worker_url_1 +``` + +Example: + +```bash +$ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30001 +$ curl -X POST http://localhost:30000/add_worker?url=http://127.0.0.1:30001 +Successfully added worker: http://127.0.0.1:30001 +``` + +- `/remove_worker` + +Usage: + +```bash +$ curl -X POST http://localhost:30000/remove_worker?url=http://worker_url_1 +``` + +Example: + +```bash +$ curl -X POST http://localhost:30000/remove_worker?url=http://127.0.0.1:30001 +Successfully removed worker: http://127.0.0.1:30001 +``` + +Note: + +- For cache-aware router, the worker will be removed from the tree and the queues. + +## Fault Tolerance + +We provide retries based for failure tolerance. + +1. If the request to a worker fails for `max_worker_retries` times, the router will remove the worker from the router and move on to the next worker. +2. If the total number of retries exceeds `max_total_retries`, the router will return an error. + +Note: + +- `max_worker_retries` is 3 and `max_total_retries` is 6 by default. + +## Routing Strategies + +#### Cache-Aware Load-Balancing Router The native router combines two strategies to optimize both cache utilization and request distribution: diff --git a/docs/start/install.md b/docs/start/install.md index e9d3abc8e78..a5012d6fc70 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -5,7 +5,8 @@ You can install SGLang using any of the methods below. ## Method 1: With pip ``` pip install --upgrade pip -pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/ +pip install sgl-kernel --force-reinstall --no-deps +pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ ``` Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. @@ -13,23 +14,25 @@ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/ ## Method 2: From source ``` # Use the last release branch -git clone -b v0.4.0.post1 https://github.com/sgl-project/sglang.git +git clone -b v0.4.2.post1 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip -pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/ +pip install sgl-kernel --force-reinstall --no-deps +pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ ``` -Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. +Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. If you meet with issue like **ImportError: cannot import name `_grouped_size_compiled_for_decode_kernels`**, installing FlashInfer with some older version like 0.1.6 instead of the latest version could solve it. 
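+The snippet below is a small, optional helper (not part of the official instructions) for checking which PyTorch, CUDA, and FlashInfer versions are actually installed before picking a wheel index:
+
+```python
+# Print the locally installed versions relevant to choosing a FlashInfer wheel.
+from importlib.metadata import PackageNotFoundError, version
+
+import torch
+
+print("torch:", torch.__version__, "| CUDA:", torch.version.cuda)
+try:
+    print("flashinfer:", version("flashinfer"))
+except PackageNotFoundError:
+    print("flashinfer is not installed")
+```
+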
Note: To AMD ROCm system with Instinct/MI GPUs, do following instead: ``` # Use the last release branch -git clone -b v0.4.0.post1 https://github.com/sgl-project/sglang.git +git clone -b v0.4.2.post1 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip +pip install sgl-kernel --force-reinstall --no-deps pip install -e "python[all_hip]" ``` @@ -51,7 +54,7 @@ docker run --gpus all \ Note: To AMD ROCm system with Instinct/MI GPUs, it is recommended to use `docker/Dockerfile.rocm` to build images, example and usage as below: ```bash -docker build --build-arg SGL_BRANCH=v0.4.0.post1 -t v0.4.0.post1-rocm620 -f Dockerfile.rocm . +docker build --build-arg SGL_BRANCH=v0.4.2.post1 -t v0.4.2.post1-rocm620 -f Dockerfile.rocm . alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --ipc=host \ --shm-size 16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ @@ -60,11 +63,11 @@ alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/d drun -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=" \ - v0.4.0.post1-rocm620 \ + v0.4.2.post1-rocm620 \ python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 # Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default -drun v0.4.0.post1-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 +drun v0.4.2.post1-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 ``` ## Method 4: Using docker compose diff --git a/examples/frontend_language/quick_start/shortfin_example_chat.py b/examples/frontend_language/quick_start/shortfin_example_chat.py index c6d0f6a5acc..107805c68e2 100644 --- a/examples/frontend_language/quick_start/shortfin_example_chat.py +++ b/examples/frontend_language/quick_start/shortfin_example_chat.py @@ -23,6 +23,7 @@ def multi_turn_question(s, question_1, question_2): s += sgl.user(question_2) s += sgl.assistant(sgl.gen("answer_2", max_tokens=256)) + @sgl.function def tip_suggestion(s): s += ( @@ -39,6 +40,7 @@ def tip_suggestion(s): s += "Tip 2:" + forks[1]["detailed_tip"] + "\n" s += "In summary" + sgl.gen("summary") + def single(): state = multi_turn_question.run( question_1="What is the capital of the United States?", @@ -49,6 +51,8 @@ def single(): print(m["role"], ":", m["content"]) print("\n-- answer_1 --\n", state["answer_1"]) + print("\n-- answer_2 --\n", state["answer_2"]) + def stream(): state = multi_turn_question.run( @@ -61,10 +65,12 @@ def stream(): print(out, end="", flush=True) print() + def fork(): state = tip_suggestion.run() print(state.text()) + def batch(): states = multi_turn_question.run_batch( [ @@ -86,6 +92,7 @@ def batch(): print() print() + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--base_url", default="http://localhost:8000") diff --git a/examples/frontend_language/usage/json_decode.py b/examples/frontend_language/usage/json_decode.py index ce8f5ba7062..5dc3522d512 100644 --- a/examples/frontend_language/usage/json_decode.py +++ b/examples/frontend_language/usage/json_decode.py @@ -9,7 +9,7 @@ from pydantic import BaseModel import sglang as sgl -from sglang.srt.constrained import build_regex_from_object +from sglang.srt.constrained.outlines_backend 
import build_regex_from_object character_regex = ( r"""\{\n""" diff --git a/examples/frontend_language/usage/triton/models/character_generation/1/model.py b/examples/frontend_language/usage/triton/models/character_generation/1/model.py index 5550e93984b..4bf86f1b691 100644 --- a/examples/frontend_language/usage/triton/models/character_generation/1/model.py +++ b/examples/frontend_language/usage/triton/models/character_generation/1/model.py @@ -3,8 +3,8 @@ from pydantic import BaseModel import sglang as sgl -from sglang import function, set_default_backend -from sglang.srt.constrained import build_regex_from_object +from sglang import function +from sglang.srt.constrained.outlines_backend import build_regex_from_object sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000")) diff --git a/examples/runtime/async_io_api.py b/examples/runtime/async_io_api.py deleted file mode 100644 index 23d3d0b90bf..00000000000 --- a/examples/runtime/async_io_api.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Usage: - -python3 async_io.py -""" - -import asyncio - -from sglang import Runtime - - -async def generate( - engine, - prompt, - sampling_params, -): - tokenizer = engine.get_tokenizer() - - messages = [ - { - "role": "system", - "content": "You will be given question answer tasks.", - }, - {"role": "user", "content": prompt}, - ] - - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - - stream = engine.add_request(prompt, sampling_params) - - async for output in stream: - print(output, end="", flush=True) - print() - - -if __name__ == "__main__": - runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf") - print("--- runtime ready ---\n") - - prompt = "Who is Alan Turing?" - sampling_params = {"max_new_tokens": 128} - asyncio.run(generate(runtime, prompt, sampling_params)) - - runtime.shutdown() diff --git a/examples/runtime/engine/EAGLE_offline_batch_inference.py b/examples/runtime/engine/EAGLE_offline_batch_inference.py new file mode 100644 index 00000000000..0885959b3fc --- /dev/null +++ b/examples/runtime/engine/EAGLE_offline_batch_inference.py @@ -0,0 +1,37 @@ +import sglang as sgl + + +def main(): + # Sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + # Create a sampling params object. + sampling_params = {"temperature": 0, "max_new_tokens": 30} + + # Create an LLM. + llm = sgl.Engine( + model_path="meta-llama/Llama-2-7b-chat-hf", + speculative_algorithm="EAGLE", + speculative_draft_model_path="lmzheng/sglang-EAGLE-llama2-chat-7B", + speculative_num_steps=3, + speculative_eagle_topk=4, + speculative_num_draft_tokens=16, + ) + + outputs = llm.generate(prompts, sampling_params) + + # Print the outputs. 
+ for prompt, output in zip(prompts, outputs): + print("===============================") + print(f"Prompt: {prompt}\nGenerated text: {output['text']}") + + +# The __main__ condition is necessary here because we use "spawn" to create subprocesses +# Spawn starts a fresh program every time, if there is no __main__, it will run into infinite loop to keep spawning processes from sgl.Engine +if __name__ == "__main__": + main() diff --git a/examples/runtime/engine/offline_batch_inference.py b/examples/runtime/engine/offline_batch_inference.py index 724051eab53..92e68dcd72c 100644 --- a/examples/runtime/engine/offline_batch_inference.py +++ b/examples/runtime/engine/offline_batch_inference.py @@ -1,3 +1,8 @@ +""" +Usage: +python3 offline_batch_inference.py --model meta-llama/Llama-3.1-8B-Instruct +""" + import argparse import dataclasses diff --git a/python/pyproject.toml b/python/pyproject.toml index 7a19ac649c7..8442ff5d2de 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.4.0.post1" +version = "0.4.2.post1" description = "SGLang is yet another fast serving framework for large language models and vision language models." readme = "README.md" requires-python = ">=3.8" @@ -13,31 +13,42 @@ classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", ] -dependencies = ["requests", "tqdm", "numpy", "IPython"] +dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"] [project.optional-dependencies] -runtime_common = ["aiohttp", "decord", "fastapi", +runtime_common = [ + "aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "modelscope", "orjson", "outlines>=0.0.44,<0.1.0", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart", - "pyzmq>=25.1.2", "torchao", "uvicorn", "uvloop", - "xgrammar>=0.1.4"] -srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer>=0.1.6"] + "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop", + "xgrammar>=0.1.10" +] +srt = [ + "sglang[runtime_common]", "cuda-python", + "sgl-kernel>=0.0.3", "torch", "vllm==0.6.4.post1", + "flashinfer==0.1.6" +] # HIP (Heterogeneous-computing Interface for Portability) for AMD # => base docker rocm/vllm-dev:20241022, not from public vllm whl -srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"] +srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post2.dev1"] # xpu is not enabled in public vllm and torch whl, # need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm srt_xpu = ["sglang[runtime_common]"] #For Intel Gaudi(device : hpu) follow the installation guide #https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html -srt_hpu = ["sglang[runtime_common]"] +srt_hpu = ["sglang[runtime_common]"] +# CPU: currently, there are no pre-built vllm wheels for CPU. 
+# To install vllm for CPU, please follow the instruction here: +# https://docs.vllm.ai/en/latest/getting_started/installation/cpu/index.html +srt_cpu = ["sglang[runtime_common]", "torch"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] +torch_memory_saver = ["torch_memory_saver"] test = [ "jsonlines", "matplotlib", @@ -50,15 +61,21 @@ all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] +all_cpu = ["sglang[srt_cpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] + dev = ["sglang[all]", "sglang[test]"] dev_hip = ["sglang[all_hip]", "sglang[test]"] dev_xpu = ["sglang[all_xpu]", "sglang[test]"] dev_hpu = ["sglang[all_hpu]", "sglang[test]"] +dev_cpu = ["sglang[all_cpu]", "sglang[test]"] [project.urls] "Homepage" = "https://github.com/sgl-project/sglang" "Bug Tracker" = "https://github.com/sgl-project/sglang/issues" +[tool.setuptools.package-data] +"sglang" = ["srt/layers/moe/fused_moe_triton/configs/*.json", "srt/layers/quantization/configs/*.json"] + [tool.setuptools.packages.find] exclude = [ "assets*", diff --git a/python/sglang/README.md b/python/sglang/README.md index 29a7149defe..6221cdb2c27 100644 --- a/python/sglang/README.md +++ b/python/sglang/README.md @@ -11,4 +11,5 @@ - `check_env.py`: Check the environment variables. - `global_config.py`: The global configs and constants. - `launch_server.py`: The entry point for launching the local server. +- `llama3_eval.py`: Evaluation of Llama 3.1 using the Meta Llama dataset. - `utils.py`: Common utilities. 
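To illustrate what the `[tool.setuptools.package-data]` entry above provides, here is a small sketch (assuming an installed `sglang` wheel whose layout matches those globs) that lists a few of the bundled JSON configs via `importlib.resources`:

```python
from importlib import resources

# The package-data globs should ship these JSON tuning configs inside the wheel.
cfg_dir = (
    resources.files("sglang")
    / "srt" / "layers" / "moe" / "fused_moe_triton" / "configs"
)
print(sorted(p.name for p in cfg_dir.iterdir() if p.name.endswith(".json"))[:5])
```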
diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index cd1eb1c1281..22b070a3dae 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -1,5 +1,6 @@ -# SGL API Components +# SGLang public APIs +# Frontend Language APIs from sglang.api import ( Engine, Runtime, @@ -23,16 +24,27 @@ user_end, video, ) +from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.lang.choices import ( greedy_token_selection, token_length_normalized, unconditional_likelihood_normalized, ) +from sglang.utils import LazyImport + +Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic") +LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM") +OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI") +VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI") +Shortfin = LazyImport("sglang.lang.backend.shortfin", "Shortfin") + +# Other configs +from sglang.global_config import global_config +from sglang.version import __version__ -# SGLang DSL APIs __all__ = [ - "Runtime", "Engine", + "Runtime", "assistant", "assistant_begin", "assistant_end", @@ -52,28 +64,15 @@ "user_begin", "user_end", "video", + "RuntimeEndpoint", "greedy_token_selection", "token_length_normalized", "unconditional_likelihood_normalized", + "Anthropic", + "LiteLLM", + "OpenAI", + "VertexAI", + "Shortfin", + "global_config", + "__version__", ] - -# Global Configurations -from sglang.global_config import global_config - -__all__ += ["global_config"] - -from sglang.version import __version__ - -__all__ += ["__version__"] - -# SGLang Backends -from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint -from sglang.utils import LazyImport - -Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic") -LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM") -OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI") -VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI") -Shortfin = LazyImport("sglang.lang.backend.shortfin", "Shortfin") - -__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "Shortfin", "RuntimeEndpoint"] diff --git a/python/sglang/api.py b/python/sglang/api.py index 9a30ad492da..7ef306380a9 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -1,6 +1,5 @@ """Public APIs of the language.""" -import os import re from typing import Callable, List, Optional, Union @@ -33,19 +32,15 @@ def decorator(func): def Runtime(*args, **kwargs): - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - # Avoid importing unnecessary dependency - from sglang.srt.server import Runtime + from sglang.lang.backend.runtime_endpoint import Runtime return Runtime(*args, **kwargs) def Engine(*args, **kwargs): - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - # Avoid importing unnecessary dependency - from sglang.srt.server import Engine + from sglang.srt.entrypoints.engine import Engine return Engine(*args, **kwargs) diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index 2e9eb1ad223..9d56ff07c8b 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -27,7 +27,8 @@ sample_random_requests, set_ulimit, ) -from sglang.srt.server import Engine, Runtime +from sglang.lang.backend.runtime_endpoint import Runtime +from sglang.srt.entrypoints.engine import Engine from sglang.srt.server_args import ServerArgs @@ -39,20 +40,22 @@ class BenchArgs: dataset_path: str = "" num_prompts: int = 1000 sharegpt_output_len: Optional[int] = None + sharegpt_context_len: 
Optional[int] = None random_input_len: int = 1024 random_output_len: int = 1024 random_range_ratio: float = 0.0 - gen_num_groups: int = 64 - gen_prompts_per_group: int = 16 - gen_system_prompt_len: int = 2048 - gen_question_len: int = 128 - gen_output_len: int = 256 + gsp_num_groups: int = 64 + gsp_prompts_per_group: int = 16 + gsp_system_prompt_len: int = 2048 + gsp_question_len: int = 128 + gsp_output_len: int = 256 + seed: int = 1 disable_ignore_eos: bool = False extra_request_body: Optional[str] = None - seed: int = 1 + apply_chat_template: bool = False + profile: bool = False skip_warmup: bool = False do_not_exit: bool = False - profile: bool = False @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -82,6 +85,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=BenchArgs.sharegpt_output_len, help="Output length for each request. Overrides the output length from the ShareGPT dataset.", ) + parser.add_argument( + "--sharegpt-context-len", + type=int, + default=BenchArgs.sharegpt_context_len, + help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.", + ) parser.add_argument( "--random-input-len", type=int, @@ -102,51 +111,62 @@ def add_cli_args(parser: argparse.ArgumentParser): "used only for random dataset.", ) parser.add_argument( - "--gen-num-groups", + "--gsp-num-groups", type=int, - default=BenchArgs.gen_num_groups, + default=BenchArgs.gsp_num_groups, help="Number of groups with shared prefix, used" "only for generate-shared-prefix", ) parser.add_argument( - "--gen-prompts-per-group", + "--gsp-prompts-per-group", type=int, - default=BenchArgs.gen_prompts_per_group, + default=BenchArgs.gsp_prompts_per_group, help="Number of prompts per group of shared prefix, used" "only for generate-shared-prefix", ) parser.add_argument( - "--gen-system-prompt-len", + "--gsp-system-prompt-len", type=int, - default=BenchArgs.gen_system_prompt_len, + default=BenchArgs.gsp_system_prompt_len, help="System prompt length, used" "only for generate-shared-prefix", ) parser.add_argument( - "--gen-question-len", + "--gsp-question-len", type=int, - default=BenchArgs.gen_question_len, + default=BenchArgs.gsp_question_len, help="Question length, used" "only for generate-shared-prefix", ) parser.add_argument( - "--gen-output-len", + "--gsp-output-len", type=int, - default=BenchArgs.gen_output_len, + default=BenchArgs.gsp_output_len, help="Target length in tokens for outputs in generated-shared-prefix dataset", ) + parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument( "--disable-ignore-eos", - type=bool, - default=BenchArgs.disable_ignore_eos, + action="store_true", help="Disable ignore EOS token", ) parser.add_argument( "--extra-request-body", metavar='{"key1": "value1", "key2": "value2"}', type=str, + default=BenchArgs.extra_request_body, help="Append given JSON object to the request payload. You can use this to specify" "additional generate params like sampling params.", ) - parser.add_argument("--seed", type=int, default=1, help="The random seed.") + parser.add_argument( + "--apply-chat-template", + action="store_true", + help="Apply chat template", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. 
The endpoint must be launched with " + "SGLANG_TORCH_PROFILER_DIR to enable profiler.", + ) parser.add_argument( "--skip-warmup", action="store_true", @@ -157,12 +177,6 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Do not exit the program. This is useful for nsys profile with --duration and --delay.", ) - parser.add_argument( - "--profile", - action="store_true", - help="Use Torch Profiler. The endpoint must be launched with " - "SGLANG_TORCH_PROFILER_DIR to enable profiler.", - ) @classmethod def from_cli_args(cls, args: argparse.Namespace): @@ -201,18 +215,17 @@ def throughput_test_once( for r in reqs ] - st = time.perf_counter() if profile: backend.start_profile() + st = time.perf_counter() gen_out = backend.generate(prompt=prompt, sampling_params=sampling_params) + latency = time.perf_counter() - st if profile: backend.stop_profile() monitor_trace_file(os.getenv("SGLANG_TORCH_PROFILER_DIR")) - latency = time.perf_counter() - st - if backend_name == "runtime": gen_out = json.loads(gen_out) @@ -285,7 +298,7 @@ def throughput_test( else: raise ValueError('Please set backend to either "engine" or "runtime"') - tokenizer_id = server_args.model_path + tokenizer_id = server_args.tokenizer_path or server_args.model_path tokenizer = get_tokenizer(tokenizer_id) # Set global environmnets @@ -304,8 +317,8 @@ def throughput_test( warmup_requests = sample_random_requests( input_len=256, output_len=16, - num_prompts=16, - range_ratio=0.8, + num_prompts=min(bench_args.num_prompts, 16), + range_ratio=1.0, tokenizer=tokenizer, dataset_path=bench_args.dataset_path, ) @@ -321,6 +334,7 @@ def throughput_test( extra_request_body=extra_request_body, profile=False, ) + time.sleep(0.5) logging.info("\nBenchmark...") result = throughput_test_once( @@ -331,6 +345,7 @@ def throughput_test( extra_request_body=extra_request_body, profile=bench_args.profile, ) + backend.shutdown() if bench_args.result_filename: with open(bench_args.result_filename, "a") as fout: diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index e7a83139954..de846066e63 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -9,7 +9,8 @@ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy ## sweep through multiple data points and store (append) the results in a jsonl file: python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run - +## run with profiling: +python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile # Usage (correctness test): python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct @@ -56,14 +57,21 @@ import torch.distributed as dist from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.entrypoints.engine import _set_envs_and_config from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_params import SamplingParams -from sglang.srt.server import _set_envs_and_config from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers +from 
sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.utils import ( + configure_logger, + get_bool_env_var, + kill_process_tree, + set_gpu_proc_affinity, + suppress_other_loggers, +) @dataclasses.dataclass @@ -76,6 +84,8 @@ class BenchArgs: correctness_test: bool = False # This is only used for correctness test cut_len: int = 4 + profile: bool = False + profile_filename_prefix: str = "profile" @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -94,6 +104,16 @@ def add_cli_args(parser: argparse.ArgumentParser): ) parser.add_argument("--correctness-test", action="store_true") parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len) + parser.add_argument( + "--profile", action="store_true", help="Use Torch Profiler." + ) + parser.add_argument( + "--profile-filename-prefix", + type=str, + default=BenchArgs.profile_filename_prefix, + help="Prefix of the profiling file names. The full profiling result file(s) be " + '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"', + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): @@ -214,6 +234,8 @@ def extend(reqs, model_runner): tree_cache=None, model_config=model_runner.model_config, enable_overlap=False, + spec_algorithm=SpeculativeAlgorithm.NONE, + enable_custom_logit_processor=False, ) batch.prepare_for_extend() model_worker_batch = batch.get_model_worker_batch() @@ -284,7 +306,16 @@ def synchronize(device): def latency_test_run_once( - run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device + run_name, + model_runner, + rank_print, + reqs, + batch_size, + input_len, + output_len, + device, + profile, + profile_filename_prefix, ): max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len) if batch_size > max_batch_size: @@ -306,6 +337,17 @@ def latency_test_run_once( tot_latency = 0 + profiler = None + if profile: + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + ) + profiler.start() + # Prefill synchronize(device) tic = time.time() @@ -336,6 +378,14 @@ def latency_test_run_once( f"Decode. 
latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s" ) + if profile: + profiler.stop() + profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz" + parent_dir = os.path.dirname(os.path.abspath(profile_filename)) + os.makedirs(parent_dir, exist_ok=True) + profiler.export_chrome_trace(profile_filename) + rank_print(f"torch profiler chrome trace saved to {profile_filename}") + # Record decode timing from 2nd output if output_len > 1: med_decode_latency = np.median(decode_latencies) @@ -361,6 +411,10 @@ def latency_test( bench_args, tp_rank, ): + # Set CPU affinity + if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): + set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, tp_rank) + # Configure the logger configure_logger(server_args, prefix=f" TP{tp_rank}") rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None @@ -384,7 +438,10 @@ def latency_test( bench_args.input_len[0], 8, # shorter decoding to speed up the warmup server_args.device, + profile=False, + profile_filename_prefix="", # not used ) + rank_print("Benchmark ...") # Run the sweep @@ -402,6 +459,8 @@ def latency_test( il, ol, server_args.device, + bench_args.profile if tp_rank == 0 else None, + bench_args.profile_filename_prefix, ) if ret is not None: result_list.append(ret) diff --git a/python/sglang/bench_one_batch_server.py b/python/sglang/bench_one_batch_server.py index 01cc561e1ce..5f0759a7ce1 100644 --- a/python/sglang/bench_one_batch_server.py +++ b/python/sglang/bench_one_batch_server.py @@ -22,7 +22,7 @@ import numpy as np import requests -from sglang.srt.server import launch_server +from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import ServerArgs from sglang.srt.utils import kill_process_tree diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 808a25ddda8..e8fdd3eadd3 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -321,6 +321,8 @@ async def async_request_sglang_generate( }, "stream": not args.disable_stream, "lora_path": request_func_input.lora_name, + "return_logprob": args.return_logprob, + "logprob_start_len": -1, **request_func_input.extra_request_body, } headers = {} @@ -527,6 +529,8 @@ def get_dataset(args, tokenizer): num_requests=args.num_prompts, tokenizer=tokenizer, fixed_output_len=args.sharegpt_output_len, + context_len=args.sharegpt_context_len, + apply_chat_template=args.apply_chat_template, ) elif args.dataset_name == "random": input_requests = sample_random_requests( @@ -539,11 +543,11 @@ def get_dataset(args, tokenizer): ) elif args.dataset_name == "generated-shared-prefix": input_requests = sample_generated_shared_prefix_requests( - num_groups=args.gen_num_groups, - prompts_per_group=args.gen_prompts_per_group, - system_prompt_len=args.gen_system_prompt_len, - question_len=args.gen_question_len, - output_len=args.gen_output_len, + num_groups=args.gsp_num_groups, + prompts_per_group=args.gsp_prompts_per_group, + system_prompt_len=args.gsp_system_prompt_len, + question_len=args.gsp_question_len, + output_len=args.gsp_output_len, tokenizer=tokenizer, ) else: @@ -590,6 +594,9 @@ class BenchmarkMetrics: p99_itl_ms: float mean_e2e_latency_ms: float median_e2e_latency_ms: float + std_e2e_latency_ms: float + p99_e2e_latency_ms: float + concurrency: float SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" @@ -634,12 
+641,14 @@ def sample_sharegpt_requests( num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, + context_len: Optional[int] = None, + apply_chat_template=False, ) -> List[Tuple[str, int, int]]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") # Download sharegpt if necessary - if not os.path.isfile(dataset_path): + if not os.path.isfile(dataset_path) and dataset_path == "": dataset_path = download_and_cache_file(SHAREGPT_URL) # Load the dataset. @@ -664,6 +673,15 @@ def sample_sharegpt_requests( # Tokenize the prompts and completions. prompt = dataset[i][0] + + if apply_chat_template: + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) + prompt = prompt.replace(tokenizer.bos_token, "") + prompt_token_ids = tokenizer.encode(prompt) completion = dataset[i][1] completion_token_ids = tokenizer.encode(completion) @@ -671,14 +689,15 @@ def sample_sharegpt_requests( output_len = ( len(completion_token_ids) if fixed_output_len is None else fixed_output_len ) - if prompt_len < 4 or output_len < 4: + + if prompt_len < 2 or output_len < 2: # Prune too short sequences. continue - if prompt_len > 1024 or ( - prompt_len + output_len > 2048 and fixed_output_len is None - ): + + if context_len and prompt_len + output_len > context_len: # Prune too long sequences. continue + filtered_dataset.append((prompt, prompt_len, output_len)) print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}") @@ -780,8 +799,8 @@ def get_gen_prefix_cache_path(args, tokenizer): # Create a unique cache filename based on the generation parameters cache_key = ( - f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_" - f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_" + f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_" + f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_" f"{tokenizer.__class__.__name__}.pkl" ) return cache_dir / cache_key @@ -949,6 +968,9 @@ def calculate_metrics( p99_itl_ms=np.percentile(itls or 0, 99) * 1000, mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, median_e2e_latency_ms=np.median(e2e_latencies) * 1000, + std_e2e_latency_ms=np.std(e2e_latencies) * 1000, + p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000, + concurrency=np.sum(e2e_latencies) / dur_s, ) return metrics, output_lens @@ -973,6 +995,7 @@ async def benchmark( else: raise ValueError(f"Unknown backend: {backend}") + # Limit concurrency # From https://github.com/vllm-project/vllm/pull/9390 semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None @@ -982,6 +1005,7 @@ async def limited_request_func(request_func_input, pbar): async with semaphore: return await request_func(request_func_input=request_func_input, pbar=pbar) + # Warmup print("Starting initial single prompt test run...") test_prompt, test_prompt_len, test_output_len = input_requests[0] test_input = RequestFuncInput( @@ -989,7 +1013,7 @@ async def limited_request_func(request_func_input, pbar): prompt=test_prompt, api_url=api_url, prompt_len=test_prompt_len, - output_len=test_output_len, + output_len=min(test_output_len, 32), lora_name=lora_name, extra_request_body=extra_request_body, ) @@ -1002,8 +1026,13 @@ async def limited_request_func(request_func_input, pbar): else: print("Initial test run completed. 
Starting main benchmark run...") - time.sleep(1.5) + # Flush cache + if "sglang" in backend: + requests.post(base_url + "/flush_cache") + + time.sleep(1.0) + # Start profiler if profile: print("Starting profiler...") profile_output = await async_request_profile( @@ -1014,6 +1043,7 @@ async def limited_request_func(request_func_input, pbar): pbar = None if disable_tqdm else tqdm(total=len(input_requests)) + # Run all requests benchmark_start_time = time.perf_counter() tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate): @@ -1034,6 +1064,7 @@ async def limited_request_func(request_func_input, pbar): ) outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + # Stop profiler if profile: print("Stopping profiler...") profile_output = await async_request_profile(api_url=base_url + "/stop_profile") @@ -1043,8 +1074,8 @@ async def limited_request_func(request_func_input, pbar): if pbar is not None: pbar.close() + # Compute metrics and print results benchmark_duration = time.perf_counter() - benchmark_start_time - metrics, output_lens = calculate_metrics( input_requests=input_requests, outputs=outputs, @@ -1091,6 +1122,7 @@ async def limited_request_func(request_func_input, pbar): "Total token throughput (tok/s):", metrics.total_throughput ) ) + print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency)) print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) print( "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) @@ -1122,24 +1154,41 @@ async def limited_request_func(request_func_input, pbar): and metrics.output_throughput is not None ): result = { + # Arguments "backend": args.backend, "dataset_name": args.dataset_name, "request_rate": request_rate, "max_concurrency": max_concurrency, + "sharegpt_output_len": args.sharegpt_output_len, + "random_input_len": args.random_input_len, + "random_output_len": args.random_output_len, + "random_range_ratio": args.random_range_ratio, + # Results + "duration": benchmark_duration, + "completed": metrics.completed, "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "total_output_tokens_retokenized": metrics.total_output_retokenized, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + "std_e2e_latency_ms": metrics.std_e2e_latency_ms, + "p99_e2e_latency_ms": metrics.p99_e2e_latency_ms, + "mean_ttft_ms": metrics.mean_ttft_ms, "median_ttft_ms": metrics.median_ttft_ms, + "std_ttft_ms": metrics.std_ttft_ms, + "p99_ttft_ms": metrics.p99_ttft_ms, + "mean_tpot_ms": metrics.mean_tpot_ms, + "median_tpot_ms": metrics.median_tpot_ms, + "std_tpot_ms": metrics.std_tpot_ms, + "p99_tpot_ms": metrics.p99_tpot_ms, + "mean_itl_ms": metrics.mean_itl_ms, "median_itl_ms": metrics.median_itl_ms, - "output_throughput": metrics.output_throughput, - "sharegpt_output_len": args.sharegpt_output_len, - "random_input_len": args.random_input_len, - "random_output_len": args.random_output_len, - "random_range_ratio": args.random_range_ratio, - "duration": benchmark_duration, - "completed": metrics.completed, + "std_itl_ms": metrics.std_itl_ms, + "p99_itl_ms": metrics.p99_itl_ms, + "concurrency": metrics.concurrency, } else: print(f"Error running benchmark for request rate: {request_rate}") @@ -1159,36 +1208,16 @@ async def 
limited_request_func(request_func_input, pbar): with open(output_file_name, "a") as file: file.write(json.dumps(result) + "\n") - result = { - "duration": benchmark_duration, - "completed": metrics.completed, - "total_input_tokens": metrics.total_input, - "total_output_tokens": metrics.total_output, - "total_output_tokens_retokenized": metrics.total_output_retokenized, - "request_throughput": metrics.request_throughput, - "input_throughput": metrics.input_throughput, - "output_throughput": metrics.output_throughput, - "mean_ttft_ms": metrics.mean_ttft_ms, - "median_ttft_ms": metrics.median_ttft_ms, - "std_ttft_ms": metrics.std_ttft_ms, - "p99_ttft_ms": metrics.p99_ttft_ms, - "mean_tpot_ms": metrics.mean_tpot_ms, - "median_tpot_ms": metrics.median_tpot_ms, - "std_tpot_ms": metrics.std_tpot_ms, - "p99_tpot_ms": metrics.p99_tpot_ms, - "mean_itl_ms": metrics.mean_itl_ms, - "median_itl_ms": metrics.median_itl_ms, - "std_itl_ms": metrics.std_itl_ms, - "p99_itl_ms": metrics.p99_itl_ms, - "input_lens": [output.prompt_len for output in outputs], - "output_lens": output_lens, - "ttfts": [output.ttft for output in outputs], - "itls": [output.itl for output in outputs], - "generated_texts": [output.generated_text for output in outputs], - "errors": [output.error for output in outputs], - "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, - "median_e2e_latency_ms": metrics.median_e2e_latency_ms, - } + result.update( + { + "input_lens": [output.prompt_len for output in outputs], + "output_lens": output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + } + ) return result @@ -1425,6 +1454,12 @@ def set_ulimit(target_soft_limit=65535): default=None, help="Output length for each request. Overrides the output length from the ShareGPT dataset.", ) + parser.add_argument( + "--sharegpt-context-len", + type=int, + default=None, + help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.", + ) parser.add_argument( "--random-input-len", type=int, @@ -1464,7 +1499,6 @@ def set_ulimit(target_soft_limit=65535): "actual request rate may be lower than specified with --request-rate, " "if the server is not processing requests fast enough to keep up.", ) - parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument( "--multi", action="store_true", @@ -1487,6 +1521,12 @@ def set_ulimit(target_soft_limit=65535): action="store_true", help="Disable streaming mode.", ) + parser.add_argument( + "--return-logprob", + action="store_true", + help="Return logprob.", + ) + parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument( "--disable-ignore-eos", action="store_true", @@ -1499,49 +1539,54 @@ def set_ulimit(target_soft_limit=65535): help="Append given JSON object to the request payload. You can use this to specify" "additional generate params like sampling params.", ) + parser.add_argument( + "--apply-chat-template", + action="store_true", + help="Apply chat template", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. 
The endpoint must be launched with " + "SGLANG_TORCH_PROFILER_DIR to enable profiler.", + ) + parser.add_argument( + "--lora-name", + type=str, + default=None, + help="The name of LoRA adapter", + ) group = parser.add_argument_group("generated-shared-prefix dataset arguments") group.add_argument( - "--gen-num-groups", + "--gsp-num-groups", type=int, default=64, help="Number of system prompt groups for generated-shared-prefix dataset", ) group.add_argument( - "--gen-prompts-per-group", + "--gsp-prompts-per-group", type=int, default=16, help="Number of prompts per system prompt group for generated-shared-prefix dataset", ) group.add_argument( - "--gen-system-prompt-len", + "--gsp-system-prompt-len", type=int, default=2048, help="Target length in tokens for system prompts in generated-shared-prefix dataset", ) group.add_argument( - "--gen-question-len", + "--gsp-question-len", type=int, default=128, help="Target length in tokens for questions in generated-shared-prefix dataset", ) group.add_argument( - "--gen-output-len", + "--gsp-output-len", type=int, default=256, help="Target length in tokens for outputs in generated-shared-prefix dataset", ) - parser.add_argument( - "--profile", - action="store_true", - help="Use Torch Profiler. The endpoint must be launched with " - "SGLANG_TORCH_PROFILER_DIR to enable profiler.", - ) - parser.add_argument( - "--lora-name", - type=str, - default=None, - help="The name of LoRA adapter", - ) args = parser.parse_args() run_benchmark(args) diff --git a/python/sglang/check_env.py b/python/sglang/check_env.py index aafb8c101c2..19b8a8f9b09 100644 --- a/python/sglang/check_env.py +++ b/python/sglang/check_env.py @@ -9,6 +9,13 @@ import torch +from sglang.srt.utils import is_hip + + +def is_cuda_v2(): + return torch.version.cuda is not None + + # List of packages to check versions PACKAGE_LIST = [ "sglang", @@ -63,13 +70,22 @@ def get_cuda_info(): """ Get CUDA-related information if available. """ - cuda_info = {"CUDA available": torch.cuda.is_available()} + if is_cuda_v2(): + cuda_info = {"CUDA available": torch.cuda.is_available()} + + if cuda_info["CUDA available"]: + cuda_info.update(_get_gpu_info()) + cuda_info.update(_get_cuda_version_info()) + + return cuda_info + elif is_hip(): + cuda_info = {"ROCM available": torch.cuda.is_available()} - if cuda_info["CUDA available"]: - cuda_info.update(_get_gpu_info()) - cuda_info.update(_get_cuda_version_info()) + if cuda_info["ROCM available"]: + cuda_info.update(_get_gpu_info()) + cuda_info.update(_get_cuda_version_info()) - return cuda_info + return cuda_info def _get_gpu_info(): @@ -103,34 +119,72 @@ def _get_cuda_version_info(): """ Get CUDA version information. 
""" - from torch.utils.cpp_extension import CUDA_HOME + if is_cuda_v2(): + from torch.utils.cpp_extension import CUDA_HOME - cuda_info = {"CUDA_HOME": CUDA_HOME} + cuda_info = {"CUDA_HOME": CUDA_HOME} - if CUDA_HOME and os.path.isdir(CUDA_HOME): - cuda_info.update(_get_nvcc_info()) - cuda_info.update(_get_cuda_driver_version()) + if CUDA_HOME and os.path.isdir(CUDA_HOME): + cuda_info.update(_get_nvcc_info()) + cuda_info.update(_get_cuda_driver_version()) - return cuda_info + return cuda_info + elif is_hip(): + from torch.utils.cpp_extension import ROCM_HOME as ROCM_HOME + + cuda_info = {"ROCM_HOME": ROCM_HOME} + + if ROCM_HOME and os.path.isdir(ROCM_HOME): + cuda_info.update(_get_nvcc_info()) + cuda_info.update(_get_cuda_driver_version()) + + return cuda_info + else: + cuda_info = {"CUDA_HOME": ""} + return cuda_info def _get_nvcc_info(): """ Get NVCC version information. """ - from torch.utils.cpp_extension import CUDA_HOME + if is_cuda_v2(): + from torch.utils.cpp_extension import CUDA_HOME - try: - nvcc = os.path.join(CUDA_HOME, "bin/nvcc") - nvcc_output = ( - subprocess.check_output(f'"{nvcc}" -V', shell=True).decode("utf-8").strip() - ) - return { - "NVCC": nvcc_output[ - nvcc_output.rfind("Cuda compilation tools") : nvcc_output.rfind("Build") - ].strip() - } - except subprocess.SubprocessError: + try: + nvcc = os.path.join(CUDA_HOME, "bin/nvcc") + nvcc_output = ( + subprocess.check_output(f'"{nvcc}" -V', shell=True) + .decode("utf-8") + .strip() + ) + return { + "NVCC": nvcc_output[ + nvcc_output.rfind("Cuda compilation tools") : nvcc_output.rfind( + "Build" + ) + ].strip() + } + except subprocess.SubprocessError: + return {"NVCC": "Not Available"} + elif is_hip(): + from torch.utils.cpp_extension import ROCM_HOME + + try: + hipcc = os.path.join(ROCM_HOME, "bin/hipcc") + hipcc_output = ( + subprocess.check_output(f'"{hipcc}" --version', shell=True) + .decode("utf-8") + .strip() + ) + return { + "HIPCC": hipcc_output[ + hipcc_output.rfind("HIP version") : hipcc_output.rfind("AMD clang") + ].strip() + } + except subprocess.SubprocessError: + return {"HIPCC": "Not Available"} + else: return {"NVCC": "Not Available"} @@ -139,20 +193,40 @@ def _get_cuda_driver_version(): Get CUDA driver version. 
""" versions = set() - try: - output = subprocess.check_output( - [ - "nvidia-smi", - "--query-gpu=driver_version", - "--format=csv,noheader,nounits", - ] - ) - versions = set(output.decode().strip().split("\n")) - if len(versions) == 1: - return {"CUDA Driver Version": versions.pop()} - else: - return {"CUDA Driver Versions": ", ".join(sorted(versions))} - except subprocess.SubprocessError: + if is_cuda_v2(): + try: + output = subprocess.check_output( + [ + "nvidia-smi", + "--query-gpu=driver_version", + "--format=csv,noheader,nounits", + ] + ) + versions = set(output.decode().strip().split("\n")) + if len(versions) == 1: + return {"CUDA Driver Version": versions.pop()} + else: + return {"CUDA Driver Versions": ", ".join(sorted(versions))} + except subprocess.SubprocessError: + return {"CUDA Driver Version": "Not Available"} + elif is_hip(): + try: + output = subprocess.check_output( + [ + "rocm-smi", + "--showdriverversion", + "--csv", + ] + ) + versions = set(output.decode().strip().split("\n")) + versions.discard("name, value") + ver = versions.pop() + ver = ver.replace('"Driver version", ', "").replace('"', "") + + return {"ROCM Driver Version": ver} + except subprocess.SubprocessError: + return {"ROCM Driver Version": "Not Available"} + else: return {"CUDA Driver Version": "Not Available"} @@ -160,16 +234,31 @@ def get_gpu_topology(): """ Get GPU topology information. """ - try: - result = subprocess.run( - ["nvidia-smi", "topo", "-m"], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - check=True, - ) - return "\n" + result.stdout if result.returncode == 0 else None - except subprocess.SubprocessError: + if is_cuda_v2(): + try: + result = subprocess.run( + ["nvidia-smi", "topo", "-m"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + ) + return "\n" + result.stdout if result.returncode == 0 else None + except subprocess.SubprocessError: + return None + elif is_hip(): + try: + result = subprocess.run( + ["rocm-smi", "--showtopotype"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True, + ) + return "\n" + result.stdout if result.returncode == 0 else None + except subprocess.SubprocessError: + return None + else: return None @@ -196,7 +285,10 @@ def check_env(): gpu_topo = get_gpu_topology() if gpu_topo: - env_info["NVIDIA Topology"] = gpu_topo + if is_cuda_v2(): + env_info["NVIDIA Topology"] = gpu_topo + elif is_hip(): + env_info["AMD Topology"] = gpu_topo hypervisor_vendor = get_hypervisor_vendor() if hypervisor_vendor: diff --git a/python/sglang/lang/backend/openai.py b/python/sglang/lang/backend/openai.py index 6fa93d9b2eb..4f37da79b7e 100644 --- a/python/sglang/lang/backend/openai.py +++ b/python/sglang/lang/backend/openai.py @@ -366,6 +366,11 @@ def select( def openai_completion( client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs ): + # if "ebnf" is in kwargs, warn and remove + if "ebnf" in kwargs: + warnings.warn("EBNF is not officially supported by OpenAI endpoints. Ignoring.") + del kwargs["ebnf"] + for attempt in range(retries): try: if is_chat: @@ -398,6 +403,11 @@ def openai_completion( def openai_completion_stream( client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs ): + # if "ebnf" is in kwargs, warn and remove + if "ebnf" in kwargs: + warnings.warn("EBNF is not officially supported by OpenAI endpoints. 
Ignoring.") + del kwargs["ebnf"] + for attempt in range(retries): try: if is_chat: diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index 779bf988d20..01f10b9f063 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -1,6 +1,11 @@ +import atexit import json +import multiprocessing import warnings -from typing import List, Optional +from typing import Dict, List, Optional, Union + +import aiohttp +import requests from sglang.global_config import global_config from sglang.lang.backend.base_backend import BaseBackend @@ -55,6 +60,7 @@ def flush_cache(self): self.base_url + "/flush_cache", api_key=self.api_key, verify=self.verify, + method="POST", ) self._assert_success(res) @@ -250,11 +256,12 @@ def select( } obj = self._generate_http_request(s, data) - normalized_prompt_logprobs = [ - r["meta_info"]["normalized_prompt_logprob"] for r in obj - ] input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj] output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj] + normalized_prompt_logprobs = [ + compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"]) + for r in obj + ] # Remove extra token if no token healing occurred for i in range(len(input_token_logprobs)): @@ -318,3 +325,176 @@ def _add_images(self, s: StreamExecutor, data): def _assert_success(self, res): if res.status_code != 200: raise RuntimeError(res.json()) + + +def compute_normalized_prompt_logprobs(input_logprobs): + values = [x[0] for x in input_logprobs if x[0]] + return sum(values) / len(values) + + +class Runtime: + """ + A wrapper for the HTTP server. + This is used for launching the server in a python program without + using the commond line interface. + + It is mainly used for the frontend language. + You should use the Engine class if you want to do normal offline processing without the frontend language. + """ + + def __init__( + self, + log_level: str = "error", + *args, + **kwargs, + ): + """See the arguments in server_args.py::ServerArgs""" + # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run + # client code without installing SRT server and its dependency if they want. + from sglang.srt.entrypoints.http_server import launch_server + from sglang.srt.server_args import ServerArgs + from sglang.srt.utils import is_port_available + + self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) + + # Pre-allocate ports + for port in range(self.server_args.port, 40000): + if is_port_available(port): + break + self.server_args.port = port + + self.url = self.server_args.url() + self.generate_url = self.url + "/generate" + + # NOTE: We store pid instead of proc to fix some issues during __delete__ + self.pid = None + pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False) + + proc = multiprocessing.Process( + target=launch_server, + args=(self.server_args, pipe_writer), + ) + proc.start() + pipe_writer.close() + self.pid = proc.pid + + # Before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown() + atexit.register(self.shutdown) + + # TODO: remove this pipe_writer mechanism and use `/health_generate` instead. + try: + init_state = pipe_reader.recv() + except EOFError: + init_state = "" + + if init_state != "ready": + self.shutdown() + raise RuntimeError( + "Initialization failed. Please see the error messages above." 
+ ) + + self.endpoint = RuntimeEndpoint(self.url) + + def shutdown(self): + from sglang.srt.utils import kill_process_tree + + if self.pid is not None: + kill_process_tree(self.pid) + self.pid = None + + def cache_prefix(self, prefix: str): + self.endpoint.cache_prefix(prefix) + + def get_tokenizer(self): + from sglang.srt.hf_transformers_utils import get_tokenizer + + return get_tokenizer( + self.server_args.tokenizer_path, + tokenizer_mode=self.server_args.tokenizer_mode, + trust_remote_code=self.server_args.trust_remote_code, + revision=self.server_args.revision, + ) + + async def async_generate( + self, + prompt: str, + sampling_params: Optional[Dict] = None, + ): + if self.server_args.skip_tokenizer_init: + json_data = { + "input_ids": prompt, + "sampling_params": sampling_params, + "stream": True, + } + else: + json_data = { + "text": prompt, + "sampling_params": sampling_params, + "stream": True, + } + pos = 0 + + timeout = aiohttp.ClientTimeout(total=3 * 3600) + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: + async with session.post(self.generate_url, json=json_data) as response: + async for chunk, _ in response.content.iter_chunks(): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]\n\n": + break + data = json.loads(chunk[5:].strip("\n")) + if "text" in data: + cur = data["text"][pos:] + if cur: + yield cur + pos += len(cur) + else: + yield data + + add_request = async_generate + + def generate( + self, + prompt: Union[str, List[str]], + sampling_params: Optional[Dict] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + ): + json_data = { + "text": prompt, + "sampling_params": sampling_params, + "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, + "top_logprobs_num": top_logprobs_num, + "lora_path": lora_path, + } + assert not isinstance(lora_path, list) or len(lora_path) == len(prompt) + response = requests.post( + self.url + "/generate", + json=json_data, + ) + return json.dumps(response.json()) + + def encode( + self, + prompt: Union[str, List[str], List[Dict], List[List[Dict]]], + ): + json_data = {"text": prompt} + response = requests.post(self.url + "/encode", json=json_data) + return json.dumps(response.json()) + + async def get_server_info(self): + async with aiohttp.ClientSession() as session: + async with session.get(f"{self.url}/get_server_info") as response: + if response.status == 200: + return await response.json() + else: + error_data = await response.json() + raise RuntimeError( + f"Failed to get server info. 
{error_data['error']['message']}" + ) + + def __del__(self): + self.shutdown() diff --git a/python/sglang/lang/backend/shortfin.py b/python/sglang/lang/backend/shortfin.py index d67a49e0d12..1ce6e50d421 100644 --- a/python/sglang/lang/backend/shortfin.py +++ b/python/sglang/lang/backend/shortfin.py @@ -20,7 +20,9 @@ def __init__( if base_url is None: raise ValueError("`base_url` is required for Shortfin backend") - self.chat_template = chat_template or get_chat_template_by_model_path("default") + self.chat_template = chat_template or get_chat_template_by_model_path( + "llama-3-instruct" + ) self.client_params = {"base_url": base_url, "timeout": timeout} @@ -39,7 +41,7 @@ def _assert_success(self, res): raise RuntimeError(res.json()) def _clean_response_message(self, text): - return text.replace(b"data: ", b"").strip(b"\n") + return text[text.find(": ") + 2 :].rstrip("\n") def get_chat_template(self): return self.chat_template @@ -61,9 +63,9 @@ def generate( ) self._assert_success(resp) - response_message = resp.resp.read() + response_message = resp.resp.read().decode() response_message = self._clean_response_message(response_message) - return response_message.decode("utf-8"), {} + return response_message, {} def generate_stream( self, @@ -81,15 +83,13 @@ def generate_stream( json=shortfin_kwargs, stream=True, timeout=self.client_params["timeout"], + method="POST", ) self._assert_success(resp) - - prefix = b"" for chunk in resp: if chunk == b"data: [DONE]\n\n": break - text = chunk[len(prefix) :] - prefix += text.strip(b"\n") + text = chunk.decode() text = self._clean_response_message(text) if text is not None: - yield text.decode("utf-8"), {} + yield text, {} diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 3e5ac8dd522..a2c91c561c2 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -88,7 +88,6 @@ def get_chat_template_by_model_path(model_path): ) ) - register_chat_template( ChatTemplate( name="claude", @@ -101,7 +100,6 @@ def get_chat_template_by_model_path(model_path): ) ) - register_chat_template( ChatTemplate( name="chatml", @@ -116,7 +114,6 @@ def get_chat_template_by_model_path(model_path): ) ) - register_chat_template( ChatTemplate( name="chatml-llava", @@ -132,7 +129,6 @@ def get_chat_template_by_model_path(model_path): ) ) - # There is default system prompt for qwen # reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1 # The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" @@ -219,6 +215,21 @@ def get_chat_template_by_model_path(model_path): ) ) +# https://huggingface.co/openbmb/MiniCPM-V-2_6 +register_chat_template( + ChatTemplate( + name="minicpmv", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ("", " "), + "user": ("user:", " "), + "assistant": ("assistant:", ""), + }, + stop_str=("<|im_end|>", "<|endoftext|>"), + image_token="(./)", + ) +) + # The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token. 
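As a rough illustration of the template API exercised in this file, registering support for a new chat format follows the same pattern as the `minicpmv` template above: register a `ChatTemplate`, then optionally a model-path matching function. The `my-chat` name and role markers below are hypothetical, not part of the patch:

```python
from sglang.lang.chat_template import (
    ChatTemplate,
    get_chat_template,
    register_chat_template,
    register_chat_template_matching_function,
)

register_chat_template(
    ChatTemplate(
        name="my-chat",  # hypothetical template name
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", "\n"),
            "user": ("USER: ", "\n"),  # hypothetical role markers
            "assistant": ("ASSISTANT: ", "\n"),
        },
        stop_str=("</s>",),
    )
)


@register_chat_template_matching_function
def match_my_chat(model_path: str):
    # Pick this template whenever the model path looks like our hypothetical model.
    if "my-chat" in model_path.lower():
        return get_chat_template("my-chat")
```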
register_chat_template( ChatTemplate( @@ -320,6 +331,59 @@ def get_chat_template_by_model_path(model_path): ) ) +register_chat_template( + ChatTemplate( + name="granite-3-instruct", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ( + "<|start_of_role|>system<|end_of_role|>", + "<|end_of_text|>", + ), + "user": ( + "<|start_of_role|>user<|end_of_role|>", + "<|end_of_text|>", + ), + "assistant": ( + "<|start_of_role|>assistant<|end_of_role|>", + "<|end_of_text|>", + ), + }, + stop_str=("<|end_of_text|>",), + ) +) + + +register_chat_template( + ChatTemplate( + name="deepseek-v3", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ( + "", + "", + ), + "user": ( + "<|User|>", + "", + ), + "assistant": ( + "<|Assistant|>", + "<|end▁of▁sentence|>", + ), + }, + stop_str=("<|end▁of▁sentence|>",), + ) +) + + +@register_chat_template_matching_function +def match_deepseek(model_path: str): + if ( + "deepseek-v3" in model_path.lower() or "deepseek-r1" in model_path.lower() + ) and "base" not in model_path.lower(): + return get_chat_template("deepseek-v3") + @register_chat_template_matching_function def match_dbrx(model_path: str): @@ -402,6 +466,16 @@ def match_c4ai_command_r(model_path: str): return get_chat_template("c4ai-command-r") +@register_chat_template_matching_function +def match_granite_instruct(model_path: str): + model_path = model_path.lower() + # When future versions of Granite are released, this code may + # need to be updated. For now, assume that the Granite 3.0 + # template works across the board. + if "granite" in model_path and "instruct" in model_path: + return get_chat_template("granite-3-instruct") + + if __name__ == "__main__": messages = [ {"role": "system", "content": None}, # None means default diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 55a20336bc7..4c294781c20 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -96,6 +96,7 @@ def run_program_batch( default_sampling_para, num_threads, progress_bar, + generator_style=False, ): if hasattr(backend, "endpoint"): backend = backend.endpoint @@ -109,6 +110,17 @@ def run_program_batch( num_threads = max(96, multiprocessing.cpu_count() * 16) num_threads = min(num_threads, len(batch_arguments)) + if generator_style: + return _run_program_batch_generator( + program, + backend, + batch_arguments, + default_sampling_para, + num_threads, + progress_bar, + ) + + # Original code path when generator_style=False if num_threads == 1: rets = [] if progress_bar: @@ -168,6 +180,64 @@ def run_program_batch( return rets +def _run_program_batch_generator( + program, + backend, + batch_arguments, + default_sampling_para, + num_threads, + progress_bar, +): + """Helper function that yields results one by one using chunking to avoid overwhelming ThreadPoolExecutor.""" + if num_threads == 1: + iterator = tqdm.tqdm(batch_arguments) if progress_bar else batch_arguments + for arguments in iterator: + yield run_program( + program, + backend, + (), + arguments, + default_sampling_para, + False, + True, + ) + else: + pbar = tqdm.tqdm(total=len(batch_arguments)) if progress_bar else None + + # Process in chunks to avoid overwhelming ThreadPoolExecutor + # Otherwise, ThreadPoolExecutor.submit will block after adding certain number of tasks + # so we will never reach "yield" until all tasks are done + chunk_size = 200 + + with ThreadPoolExecutor(num_threads) as executor: + for chunk_start in range(0, len(batch_arguments), chunk_size): + 
chunk_end = min(chunk_start + chunk_size, len(batch_arguments)) + chunk_futures = [] + + # Submit chunk of tasks + for i in range(chunk_start, chunk_end): + future = executor.submit( + run_program, + program, + backend, + (), + batch_arguments[i], + default_sampling_para, + False, + True, + ) + if pbar: + future.add_done_callback(lambda _: pbar.update()) + chunk_futures.append(future) + + # Yield results from this chunk as they complete + for future in chunk_futures: + yield future.result() + + if pbar: + pbar.close() + + def cache_program(program, backend): from sglang.lang.tracer import extract_prefix_by_tracing @@ -277,7 +347,7 @@ def fork( size: int = 1, position_ids_offset: Optional[List[int]] = None, ): - if size > 1: + if size > 1 and str(self.text_): self.submit(SglCommitLazy()) self.sync() diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 95d7c29751d..f07b1929e51 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -124,7 +124,7 @@ def to_shortfin_kwargs(self): else False ) kwargs["sampling_params"] = { - "max_completion_tokens": self.max_new_tokens, + "max_completion_tokens": 10, "temperature": self.temperature, } return kwargs @@ -244,6 +244,7 @@ def run_batch( backend=None, num_threads: Union[str, int] = "auto", progress_bar: bool = False, + generator_style: bool = False, ): from sglang.lang.interpreter import run_program_batch @@ -294,6 +295,7 @@ def run_batch( default_sampling_para, num_threads, progress_bar, + generator_style=generator_style, ) def trace(self, *, backend=None, **kwargs): diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 6b0c25711c6..caae7b0f6cc 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -3,7 +3,7 @@ import os import sys -from sglang.srt.server import launch_server +from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import prepare_server_args from sglang.srt.utils import kill_process_tree diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py deleted file mode 100644 index 138c2127e16..00000000000 --- a/python/sglang/launch_server_llavavid.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Launch the inference server for Llava-video model.""" - -import json -import sys - -from sglang.srt.server import launch_server, prepare_server_args - -if __name__ == "__main__": - server_args = prepare_server_args(sys.argv[1:]) - - model_override_args = {} - model_override_args["mm_spatial_pool_stride"] = 2 - model_override_args["architectures"] = ["LlavaVidForCausalLM"] - model_override_args["num_frames"] = 16 - model_override_args["model_type"] = "llavavid" - if model_override_args["num_frames"] == 32: - model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"} - model_override_args["max_sequence_length"] = 4096 * 2 - model_override_args["tokenizer_model_max_length"] = 4096 * 2 - model_override_args["model_max_length"] = 4096 * 2 - if "34b" in server_args.model_path.lower(): - model_override_args["image_token_index"] = 64002 - server_args.json_model_override_args = json.dumps(model_override_args) - - launch_server(server_args) diff --git a/python/sglang/llama3_eval.py b/python/sglang/llama3_eval.py new file mode 100644 index 00000000000..35bd4a7e4d4 --- /dev/null +++ b/python/sglang/llama3_eval.py @@ -0,0 +1,316 @@ +# Adapt from https://github.com/fw-ai/llm_eval_meta + +import argparse +import asyncio +import os +import pickle +import re +import shutil +from collections import 
defaultdict +from dataclasses import dataclass + +import httpx +import numpy as np +import openai +import transformers +from datasets import load_dataset +from openai import AsyncOpenAI +from tqdm import tqdm + +# Mapping providers to their clients and models +provider_to_models = { + "b10": { + "8b": "meta-llama/Llama-3.1-8B-Instruct", + "70b": "meta-llama/Llama-3.1-70B-Instruct", + "405b": "meta-llama/Llama-3.1-405B-Instruct", + }, + "oai": { + "8b": "meta-llama/Llama-3.1-8B-Instruct", + "70b": "meta-llama/Llama-3.1-70B-Instruct", + "405b": "meta-llama/Llama-3.1-405B-Instruct", + }, + "sgl": { + "8b": "meta-llama/Llama-3.1-8B-Instruct", + "70b": "meta-llama/Llama-3.1-70B-Instruct", + "405b": "meta-llama/Llama-3.1-405B-Instruct", + }, +} + + +async def fetch_responses( + client, prompt, semaphore, index, provider, model_size, output_dir, max_tokens +): + output_file = os.path.join(output_dir, f"response_{index}.pkl") + if os.path.exists(output_file): + print(f"File {output_file} already exists, skipping.") + return + + async with semaphore: + response = await client.completions.create( + model=provider_to_models[provider][model_size], + prompt=prompt, + temperature=0.0, + max_tokens=max_tokens, + ) + if isinstance(response, openai.BadRequestError): + with open(output_file, "wb") as f: + pickle.dump("bad_response", f) + assert isinstance(response, openai.types.completion.Completion) + # Save response to a file + with open(output_file, "wb") as f: + pickle.dump(response, f) + + +TASK_TO_MAX_TOKENS = { + "evals__mmlu__details": 1, + "evals__mmlu__0_shot__cot__details": 1024, + # Official meta uses 1024, but a small % (.05) of questions are answered correctly after relaxing + "evals__mmlu_pro__details": 2048, + "evals__gsm8k__details": 1024, +} + +TASK_TO_EVAL_SET = { + "mmlu": "evals__mmlu__details", + "mmlu_cot": "evals__mmlu__0_shot__cot__details", + "mmlu_pro": "evals__mmlu_pro__details", + "gsm8k": "evals__gsm8k__details", +} + + +class CustomAsyncHTTPXClient(httpx.AsyncClient): + async def send(self, request: httpx.Request, *args, **kwargs) -> httpx.Response: + request.url = httpx.URL( + f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict" + ) + return await super().send(request, *args, **kwargs) + + +def get_client(provider): + if provider not in "b10": + if os.getenv("OPENAI_API_KEY") == None: + os.environ["OPENAI_API_KEY"] = "EMPTY" + return { + "oai": AsyncOpenAI(base_url="http://127.0.0.1:8000/v1/"), + "b10": AsyncOpenAI( + api_key=f"Api-Key {os.getenv('OPENAI_API_KEY')}", + base_url=f"https://model-{os.getenv('MODEL_ID')}.api.baseten.co/development/predict", + http_client=CustomAsyncHTTPXClient(), + ), + "sgl": AsyncOpenAI(base_url="http://127.0.0.1:30000/v1/"), + }[provider] + + +# Define the benchmark function +async def benchmark(args): + ds = load_dataset( + "meta-llama/Llama-3.1-405B-Instruct-evals", + f"Llama-3.1-405B-Instruct-{TASK_TO_EVAL_SET[args.task]}", + ) + semaphore = asyncio.Semaphore(args.concurrency) # Limit to 16 concurrent tasks + + if args.num_examples is None: + args.num_examples = len(ds["latest"]["input_final_prompts"]) + prompts = ds["latest"]["input_final_prompts"][: args.num_examples] + + # Create the output directory if it does not exist + os.makedirs(args.output_dir, exist_ok=True) + + tasks = [] + # Create the tasks with tqdm progress bar + max_tokens = TASK_TO_MAX_TOKENS[TASK_TO_EVAL_SET[args.task]] + client = get_client(args.provider) + for idx, prompt in enumerate(tqdm(prompts, desc="Creating tasks")): + tasks.append( + 
asyncio.create_task( + fetch_responses( + client, + f"<|begin_of_text|>{prompt[0]}", + semaphore, + idx, + args.provider, + args.model_size, + args.output_dir, + max_tokens=max_tokens, + ) + ) + ) + + # Run the tasks with tqdm progress bar + for future in tqdm( + asyncio.as_completed(tasks), total=len(tasks), desc="Processing tasks" + ): + await future + + +def get_mmlu_answer(response): + if response is not None: + return response.choices[0].text.lstrip().rstrip().upper().replace(".", "") + return None + + +def get_mmlu_cot_answer(response): + pattern = r"The best answer is (.+)\.?" + match = re.search(pattern, response.choices[0].text) + if match: + return match.group(1).replace(".", "").replace("*", "") + + pattern = r"the best answer is (.+)\.?" + match = re.search(pattern, response.choices[0].text) + if match: + return match.group(1).replace(".", "") + + pattern = r"The correct answer is (.+)\.?" + match = re.search(pattern, response.choices[0].text) + if match: + return match.group(1).replace(".", "") + + pattern = r"the correct answer is (.+)\.?" + match = re.search(pattern, response.choices[0].text) + if match: + return match.group(1).replace(".", "") + + +def get_answer_gsm8k(response): + pattern = r"The final answer is (.+)\.?" + match = re.search(pattern, response.choices[0].text) + if match: + s = match.group(1) + for ok_symbol in ["%", "$"]: + s = s.replace(ok_symbol, "") + return s + + +TASK_TO_ANSWER_EXTRACTOR = { + "evals__mmlu__details": get_mmlu_answer, + "evals__mmlu__0_shot__cot__details": get_mmlu_cot_answer, + "evals__gsm8k__details": get_answer_gsm8k, + "evals__mmlu_pro__details": get_mmlu_cot_answer, +} + + +def get_dataset_from_task(task, response_path, model_size): + ds_405b = load_dataset( + f"meta-llama/Llama-3.1-405B-Instruct-evals", + f"Llama-3.1-405B-Instruct-{task}", + ) + ds_405b_hash_order = [x[0] for x in ds_405b["latest"]["input_final_prompts_hash"]] + + if "70b" in model_size or "8b" in model_size: + if "70" in model_size: + ref_model_ds = load_dataset( + f"meta-llama/Llama-3.1-70B-Instruct-evals", + f"Llama-3.1-70B-Instruct-{task}", + ) + else: + ref_model_ds = load_dataset( + f"meta-llama/Llama-3.1-8B-Instruct-evals", + f"Llama-3.1-8B-Instruct-{task}", + ) + + hash_to_row = {} + for row in ref_model_ds["latest"]: + hash_to_row[row["input_final_prompts_hash"][0]] = row + reordered_rows = [] + for prompt_hash in ds_405b_hash_order: + reordered_rows.append(hash_to_row[prompt_hash]) + ref_model_ds["latest"] = reordered_rows + return ref_model_ds + + return ds_405b + + +def analyze(task, response_path, model_size): + ds = get_dataset_from_task(task, response_path, model_size) + + responses = [] + total = len(ds["latest"]) + + for i in range(0, total): + response = pickle.load( + open(os.path.join(response_path, f"response_{i}.pkl"), "rb") + ) + responses.append(response) + + @dataclass + class Stats: + correct: int = 0 + total: int = 0 + meta_correct: int = 0 + + average: float = None + + subtask_name_to_stats = defaultdict(lambda: Stats()) + + for response, ds_row in zip(responses, ds["latest"]): + model_answer = TASK_TO_ANSWER_EXTRACTOR[task](response) + + subtask = ds_row["subtask_name"] + + is_eval_correct = model_answer in ds_row["input_correct_responses"] + if is_eval_correct: + subtask_name_to_stats[subtask].correct += 1 + + if ds_row["is_correct"]: + subtask_name_to_stats[subtask].meta_correct += 1 + + subtask_name_to_stats[subtask].total += 1 + + micro_stats = Stats() + for subtask, stats in subtask_name_to_stats.items(): + stats.average = 
stats.correct / stats.total + stats.meta_average = stats.meta_correct / stats.total + + micro_stats.correct += stats.correct + micro_stats.total += stats.total + micro_stats.meta_correct += stats.meta_correct + + micro_stats.average = micro_stats.correct / micro_stats.total + micro_stats.meta_average = micro_stats.meta_correct / micro_stats.total + + print("Macro average", np.mean([x.average for x in subtask_name_to_stats.values()])) + print( + "Meta Macro average", + np.mean([x.meta_average for x in subtask_name_to_stats.values()]), + ) + print("Micro average", micro_stats.average) + print("Meta Micro average", micro_stats.meta_average) + + +# Entry point for the script +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Script to run model with specified parameters." + ) + parser.add_argument( + "--model-size", + type=str, + default="8b", + help="Size of the model (e.g., 8b or 70b)", + ) + parser.add_argument( + "--provider", + type=str, + default="sgl", + help="Provider name (e.g., sgl, oai, b10)", + ) + parser.add_argument( + "--task", + type=str, + required=True, + help="Task (e.g., mmlu, mmlu_cot, mmlu_pro, gsm8k)", + ) + parser.add_argument( + "--num-examples", type=int, default=None, help="Number of examples to process" + ) + parser.add_argument("--concurrency", type=int, default=16) + parser.add_argument( + "--output-dir", + type=str, + default="tmp-output-dir", + help="Directory to save responses", + ) + + args = parser.parse_args() + asyncio.run(benchmark(args)) + analyze(TASK_TO_EVAL_SET[args.task], args.output_dir, args.model_size) + shutil.rmtree("tmp-output-dir", ignore_errors=True) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 9eb7caa1bba..3cb313b9133 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -1,8 +1,9 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py import contextlib import functools import importlib import logging +import os from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch @@ -11,12 +12,19 @@ from sglang.srt.utils import is_hpu logger = logging.getLogger(__name__) +use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True) if not is_hpu(): - try: - import custom_ar - except ImportError as e: - logger.warning("Failed to import from custom_ar with %r", e) + if use_vllm_custom_allreduce: + try: + import vllm._C + except ImportError as e: + logger.warning("Failed to import from vllm._C with %r", e) + else: + try: + import sgl_kernel + except ImportError as e: + logger.warning("Failed to import from custom_ar with %r", e) def hint_on_error(fn): @@ -48,48 +56,78 @@ def wrapper(*args, **kwargs): return wrapper -# custom ar -def init_custom_ar( - ipc_tensors: List[torch.Tensor], - rank_data: torch.Tensor, - rank: int, - full_nvlink: bool, -) -> int: - return torch.ops._C_vllm_ar.init_custom_ar( - ipc_tensors, rank_data, rank, full_nvlink - ) - - -def all_reduce( - fa: int, - inp: torch.Tensor, - out: torch.Tensor, - reg_buffer: int, - reg_buffer_sz_bytes: int, -) -> None: - torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes) - - -def dispose(fa: int) -> None: - torch.ops._C_vllm_ar.dispose(fa) - - -def meta_size() -> int: - return torch.ops._C_vllm_ar.meta_size() - +if use_vllm_custom_allreduce: + # custom ar + def 
init_custom_ar( + ipc_tensors: List[torch.Tensor], + rank_data: torch.Tensor, + rank: int, + full_nvlink: bool, + ) -> int: + return torch.ops._C_custom_ar.init_custom_ar( + ipc_tensors, rank_data, rank, full_nvlink + ) -def register_buffer(fa: int, ipc_tensors: List[int]) -> None: - return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors) + def all_reduce( + fa: int, + inp: torch.Tensor, + out: torch.Tensor, + reg_buffer: int, + reg_buffer_sz_bytes: int, + ) -> None: + torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes) + + def dispose(fa: int) -> None: + torch.ops._C_custom_ar.dispose(fa) + + def meta_size() -> int: + return torch.ops._C_custom_ar.meta_size() + + def register_buffer(fa: int, ipc_tensors: List[int]) -> None: + return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors) + + def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: + return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) + + def register_graph_buffers( + fa: int, handles: List[List[int]], offsets: List[List[int]] + ) -> None: + torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) + +else: + # custom ar + def init_custom_ar( + rank_id: int, + world_size: int, + rank_data_base: torch.Tensor, + buffers: List[int], + tmp_result_buffers: List[int], + barrier_in: List[int], + barrier_out: List[int], + ) -> int: + return sgl_kernel.ops.init_custom_reduce( + rank_id, + world_size, + rank_data_base, + buffers, + tmp_result_buffers, + barrier_in, + barrier_out, + ) + def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: + sgl_kernel.ops.custom_reduce(fa, inp, out) -def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: - return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa) + def dispose(fa: int) -> None: + sgl_kernel.ops.custom_dispose(fa) + def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: + return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa) -def register_graph_buffers( - fa: int, handles: List[List[int]], offsets: List[List[int]] -) -> None: - torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets) + def register_graph_buffers( + fa: int, handles: List[List[int]], offsets: List[List[int]] + ) -> None: + sgl_kernel.ops.register_graph_buffers(fa, handles, offsets) # temporary fix for https://github.com/vllm-project/vllm/issues/5456 diff --git a/python/sglang/srt/aio_rwlock.py b/python/sglang/srt/aio_rwlock.py new file mode 100644 index 00000000000..deda1fe7903 --- /dev/null +++ b/python/sglang/srt/aio_rwlock.py @@ -0,0 +1,100 @@ +import asyncio + + +class RWLock: + def __init__(self): + # Protects internal state + self._lock = asyncio.Lock() + + # Condition variable used to wait for state changes + self._cond = asyncio.Condition(self._lock) + + # Number of readers currently holding the lock + self._readers = 0 + + # Whether a writer is currently holding the lock + self._writer_active = False + + # How many writers are queued waiting for a turn + self._waiting_writers = 0 + + @property + def reader_lock(self): + """ + A context manager for acquiring a shared (reader) lock. + + Example: + async with rwlock.reader_lock: + # read-only access + """ + return _ReaderLock(self) + + @property + def writer_lock(self): + """ + A context manager for acquiring an exclusive (writer) lock. 
+ + Example: + async with rwlock.writer_lock: + # exclusive access + """ + return _WriterLock(self) + + async def acquire_reader(self): + async with self._lock: + # Wait until there is no active writer or waiting writer + # to ensure fairness. + while self._writer_active or self._waiting_writers > 0: + await self._cond.wait() + self._readers += 1 + + async def release_reader(self): + async with self._lock: + self._readers -= 1 + # If this was the last reader, wake up anyone waiting + # (potentially a writer or new readers). + if self._readers == 0: + self._cond.notify_all() + + async def acquire_writer(self): + async with self._lock: + # Increment the count of writers waiting + self._waiting_writers += 1 + try: + # Wait while either a writer is active or readers are present + while self._writer_active or self._readers > 0: + await self._cond.wait() + self._writer_active = True + finally: + # Decrement waiting writers only after we've acquired the writer lock + self._waiting_writers -= 1 + + async def release_writer(self): + async with self._lock: + self._writer_active = False + # Wake up anyone waiting (readers or writers) + self._cond.notify_all() + + +class _ReaderLock: + def __init__(self, rwlock: RWLock): + self._rwlock = rwlock + + async def __aenter__(self): + await self._rwlock.acquire_reader() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._rwlock.release_reader() + + +class _WriterLock: + def __init__(self, rwlock: RWLock): + self._rwlock = rwlock + + async def __aenter__(self): + await self._rwlock.acquire_writer() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self._rwlock.release_writer() diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index 600b58e4937..3d81c5d4fd5 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -1,3 +1,5 @@ +from sglang.srt.configs.chatglm import ChatGLMConfig +from sglang.srt.configs.dbrx import DbrxConfig from sglang.srt.configs.exaone import ExaoneConfig from sglang.srt.configs.qwen2vl import Qwen2VLConfig, Qwen2VLVisionConfig @@ -5,4 +7,6 @@ "ExaoneConfig", "Qwen2VLConfig", "Qwen2VLVisionConfig", + "ChatGLMConfig", + "DbrxConfig", ] diff --git a/python/sglang/srt/configs/chatglm.py b/python/sglang/srt/configs/chatglm.py new file mode 100644 index 00000000000..9370c218aab --- /dev/null +++ b/python/sglang/srt/configs/chatglm.py @@ -0,0 +1,78 @@ +# Adapted from +# https://github.com/THUDM/ChatGLM2-6B +# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/chatglm.py + +# ChatGLM2 and ChatGLM3 share the same config. 
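# A minimal, self-contained usage sketch for the asyncio RWLock added in
# python/sglang/srt/aio_rwlock.py above (illustrative only, not part of this
# patch): several readers may hold the lock concurrently, while a queued
# writer blocks new readers until it has run.
import asyncio

from sglang.srt.aio_rwlock import RWLock


async def _demo():
    rwlock = RWLock()
    shared = {"value": 0}

    async def reader():
        async with rwlock.reader_lock:
            # Many readers can sit in this section at the same time.
            await asyncio.sleep(0.01)
            return shared["value"]

    async def writer():
        async with rwlock.writer_lock:
            # Exclusive access; pending writers also hold back new readers.
            shared["value"] += 1

    await asyncio.gather(*(reader() for _ in range(4)), writer())


asyncio.run(_demo())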
+# ChatGLM4 is officially supported by Huggingface +# transformers >= 4.46.0 is required +# https://huggingface.co/docs/transformers/en/model_doc/glm +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + attribute_map = { + "num_hidden_layers": "num_layers", + "n_head_kv": "multi_query_group_num", + } + + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + interleaved_qkv=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs + ): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + # It is to be compatible with long lora. + self.max_position_embeddings = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = ( + apply_residual_connection_post_layernorm + ) + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + self.interleaved_qkv = interleaved_qkv + super().__init__(**kwargs) diff --git a/python/sglang/srt/configs/dbrx.py b/python/sglang/srt/configs/dbrx.py new file mode 100644 index 00000000000..75ccbde944e --- /dev/null +++ b/python/sglang/srt/configs/dbrx.py @@ -0,0 +1,279 @@ +# Adapted from +# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py +# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/dbrx.py +"""Dbrx configuration.""" + +from typing import Any, Optional + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore + + +class DbrxAttentionConfig(PretrainedConfig): + """Configuration class for Dbrx Attention. + + [`DbrxAttention`] class. It is used to instantiate attention layers + according to the specified arguments, defining the layers architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + attn_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability for the attention layers. 
+ clip_qkv (`float`, *optional*, defaults to None): + If not `None`, clip the queries, keys, and values in the attention layer to this value. + kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. + rope_theta (float): The base frequency for rope. + """ + + def __init__( + self, + attn_pdrop: float = 0, + clip_qkv: Optional[float] = None, + kv_n_heads: int = 1, + rope_theta: float = 10000.0, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.attn_pdrop = attn_pdrop + self.clip_qkv = clip_qkv + self.kv_n_heads = kv_n_heads + self.rope_theta = rope_theta + + for k in ["model_type"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: str, **kwargs: Any + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + if config_dict.get("model_type") == "dbrx": + config_dict = config_dict["attn_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + "You are using a model of type %s to instantiate a model of " + "type %s. This is not supported for all configurations of " + "models and can yield errors.", + config_dict["model_type"], + cls.model_type, + ) + + return cls.from_dict(config_dict, **kwargs) + + +class DbrxFFNConfig(PretrainedConfig): + """Configuration class for Dbrx FFN. + + [`DbrxFFN`] class. It is used to instantiate feedforward layers according to + the specified arguments, defining the layers architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + ffn_act_fn (dict, optional): A dict specifying activation function for the FFN. + The dict should have a key 'name' with the value being the name of + the activation function along with any additional keyword arguments. + ffn_hidden_size (int, optional): The hidden size of the feedforward network. + moe_num_experts (int, optional): The number of experts in the mixture of experts layer. + moe_top_k (int, optional): The number of experts to use in the mixture of experts layer. + moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer. + moe_loss_weight (float, optional): The loss weight for the mixture of experts layer. + moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights. + uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment. + This should only be used for benchmarking purposes. 
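    Example (an illustrative sketch, assuming this module is importable as
    sglang.srt.configs.dbrx):
        >>> from sglang.srt.configs.dbrx import DbrxFFNConfig
        >>> ffn_config = DbrxFFNConfig(moe_num_experts=16, moe_top_k=4)
        >>> ffn_config.ffn_act_fn
        {'name': 'silu'}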
+ """ + + def __init__( + self, + ffn_act_fn: Optional[dict] = None, + ffn_hidden_size: int = 3584, + moe_num_experts: int = 4, + moe_top_k: int = 1, + moe_jitter_eps: Optional[float] = None, + moe_loss_weight: float = 0.01, + moe_normalize_expert_weights: Optional[float] = 1, + uniform_expert_assignment: bool = False, + **kwargs: Any, + ): + super().__init__() + if ffn_act_fn is None: + ffn_act_fn = {"name": "silu"} + self.ffn_act_fn = ffn_act_fn + self.ffn_hidden_size = ffn_hidden_size + self.moe_num_experts = moe_num_experts + self.moe_top_k = moe_top_k + self.moe_jitter_eps = moe_jitter_eps + self.moe_loss_weight = moe_loss_weight + self.moe_normalize_expert_weights = moe_normalize_expert_weights + self.uniform_expert_assignment = uniform_expert_assignment + + for k in ["model_type"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: str, **kwargs: Any + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + if config_dict.get("model_type") == "dbrx": + config_dict = config_dict["ffn_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + "You are using a model of type %s to instantiate a model of " + "type %s. This is not supported for all " + "configurations of models and can yield errors.", + config_dict["model_type"], + cls.model_type, + ) + + return cls.from_dict(config_dict, **kwargs) + + +class DbrxConfig(PretrainedConfig): + """Configuration class for Dbrx. + + [`DbrxModel`]. It is used to instantiate a Dbrx model according to the + specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + d_model (`int`, *optional*, defaults to 6144): + Dimensionality of the embeddings and hidden states. + n_heads (`int`, *optional*, defaults to 48): + Number of attention heads for each attention layer in the Transformer encoder. + n_layers (`int`, *optional*, defaults to 40): + Number of hidden layers in the Transformer encoder. + max_seq_len (`int`, *optional*, defaults to 32768): + The maximum sequence length of the model. + vocab_size (`int`, *optional*, defaults to 100352): + Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by + the `inputs_ids` passed when calling [`DbrxModel`]. + resid_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability applied to the attention output before combining with residual. + emb_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability for the embedding layer. + attn_config (`dict`, *optional*): + A dictionary used to configure the model's attention module. + ffn_config (`dict`, *optional*): + A dictionary used to configure the model's FFN module. + use_cache (`bool`, *optional*, defaults to `False`): + Whether or not the model should return the last key/values attentions (not used by all models). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 
+ output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + + Example: + ```python + >>> from transformers import DbrxConfig, DbrxModel + + >>> # Initializing a Dbrx configuration + >>> configuration = DbrxConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = DbrxModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "dbrx" + attribute_map = { + "num_attention_heads": "n_heads", + "hidden_size": "d_model", + "num_hidden_layers": "n_layers", + "max_position_embeddings": "max_seq_len", + } + + def __init__( + self, + d_model: int = 2048, + n_heads: int = 16, + n_layers: int = 24, + max_seq_len: int = 2048, + vocab_size: int = 32000, + resid_pdrop: float = 0.0, + emb_pdrop: float = 0.0, + attn_config: Optional[DbrxAttentionConfig] = None, + ffn_config: Optional[DbrxFFNConfig] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + output_router_logits: bool = False, + router_aux_loss_coef: float = 0.05, + **kwargs: Any, + ): + if attn_config is None: + self.attn_config = DbrxAttentionConfig() + elif isinstance(attn_config, dict): + self.attn_config = DbrxAttentionConfig(**attn_config) + else: + self.attn_config = attn_config + + if ffn_config is None: + self.ffn_config = DbrxFFNConfig() + elif isinstance(ffn_config, dict): + self.ffn_config = DbrxFFNConfig(**ffn_config) + else: + self.ffn_config = ffn_config + + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.resid_pdrop = resid_pdrop + self.emb_pdrop = emb_pdrop + self.use_cache = use_cache + self.initializer_range = initializer_range + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError("tie_word_embeddings is not supported for Dbrx models.") + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/python/sglang/srt/configs/device_config.py b/python/sglang/srt/configs/device_config.py index 74deb891902..d95e848ddae 100644 --- a/python/sglang/srt/configs/device_config.py +++ b/python/sglang/srt/configs/device_config.py @@ -10,7 +10,7 @@ class DeviceConfig: device: Optional[torch.device] def __init__(self, device: str = "cuda") -> None: - if device in ["cuda", "xpu", "hpu"]: + if device in ["cuda", "xpu", "hpu", "cpu"]: self.device_type = device else: raise RuntimeError(f"Not supported device type: {device}") diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index 2b2b341faeb..6cb35ab47c6 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -20,6 +20,7 @@ class LoadFormat(str, enum.Enum): GGUF = "gguf" BITSANDBYTES = "bitsandbytes" MISTRAL = "mistral" + LAYERED = "layered" @dataclass diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 596afb83e0b..6d144f84433 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -15,7 +15,7 @@ import json 
import logging from enum import IntEnum, auto -from typing import List, Optional, Union +from typing import List, Optional, Set, Union import torch from transformers import PretrainedConfig @@ -47,6 +47,7 @@ def __init__( self.model_path = model_path self.revision = revision self.quantization = quantization + # Parse args self.model_override_args = json.loads(model_override_args) self.hf_config = get_config( @@ -94,7 +95,10 @@ def __init__( ) # FIXME: temporary special judge for MLA architecture - if "DeepseekV2ForCausalLM" in self.hf_config.architectures: + if ( + "DeepseekV2ForCausalLM" in self.hf_config.architectures + or "DeepseekV3ForCausalLM" in self.hf_config.architectures + ): self.head_dim = 256 self.attention_arch = AttentionArch.MLA self.kv_lora_rank = self.hf_config.kv_lora_rank @@ -124,8 +128,13 @@ def __init__( self.num_hidden_layers = self.hf_text_config.num_hidden_layers self.vocab_size = self.hf_text_config.vocab_size + # Verify quantization self._verify_quantization() + # Cache attributes + self.hf_eos_token_id = self.get_hf_eos_token_id() + self.image_token_id = getattr(self.hf_config, "image_token_id", None) + # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289 def get_total_num_kv_heads(self) -> int: """Returns the total number of KV heads.""" @@ -214,7 +223,11 @@ def _verify_quantization(self) -> None: "compressed_tensors", "compressed-tensors", "experts_int8", + "w8a8_int8", ] + compatible_quantization_methods = { + "w8a8_int8": ["compressed-tensors", "compressed_tensors"] + } if self.quantization is not None: self.quantization = self.quantization.lower() @@ -238,12 +251,17 @@ def _verify_quantization(self) -> None: if self.quantization is None: self.quantization = quant_method elif self.quantization != quant_method: - raise ValueError( - "Quantization method specified in the model config " - f"({quant_method}) does not match the quantization " - f"method specified in the `quantization` argument " - f"({self.quantization})." - ) + if ( + self.quantization not in compatible_quantization_methods + or quant_method + not in compatible_quantization_methods[self.quantization] + ): + raise ValueError( + "Quantization method specified in the model config " + f"({quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization})." + ) if self.quantization is not None: if self.quantization not in supported_quantization: @@ -264,6 +282,13 @@ def _verify_quantization(self) -> None: self.quantization, ) + def get_hf_eos_token_id(self) -> Optional[Set[int]]: + eos_ids = getattr(self.hf_config, "eos_token_id", None) + if eos_ids: + # it can be either int or list of int + eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids) + return eos_ids + def get_hf_text_config(config: PretrainedConfig): """Get the "sub" config relevant to llm for multi modal models. 
@@ -377,6 +402,7 @@ def is_multimodal_model(model_architectures: List[str]): or "LlavaVidForCausalLM" in model_architectures or "MllamaForConditionalGeneration" in model_architectures or "Qwen2VLForConditionalGeneration" in model_architectures + or "MiniCPMV" in model_architectures ): return True else: diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py deleted file mode 100644 index 458d1925241..00000000000 --- a/python/sglang/srt/constrained/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# TODO(lmzheng): make this an optional dependency -from sglang.srt.constrained.outlines_backend import build_regex_from_object diff --git a/python/sglang/srt/constrained/base_grammar_backend.py b/python/sglang/srt/constrained/base_grammar_backend.py index 7c88229cf16..6f304ea171e 100644 --- a/python/sglang/srt/constrained/base_grammar_backend.py +++ b/python/sglang/srt/constrained/base_grammar_backend.py @@ -18,6 +18,8 @@ from threading import Event, Lock from typing import Any, Optional, Tuple +from sglang.srt.server_args import ServerArgs + @dataclass class CacheEntry: @@ -69,3 +71,22 @@ def get_future_value(self, key: Tuple[str, str]) -> Future: def reset(self): with self.cache_lock: self.cache.clear() + + +def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size): + if server_args.grammar_backend == "outlines": + from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend + + grammar_backend = OutlinesGrammarBackend( + tokenizer, + whitespace_pattern=server_args.constrained_json_whitespace_pattern, + allow_jump_forward=not server_args.disable_jump_forward, + ) + elif server_args.grammar_backend == "xgrammar": + from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend + + grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size) + else: + raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}") + + return grammar_backend diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index ee8e8eb07f4..c423a567eda 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -19,6 +19,7 @@ import torch from xgrammar import ( CompiledGrammar, + Grammar, GrammarCompiler, GrammarMatcher, TokenizerInfo, @@ -117,17 +118,29 @@ def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar: key_type, key_string = key if key_type == "json": try: - ctx = self.grammar_compiler.compile_json_schema(schema=key_string) + if key_string == "$$ANY$$": + ctx = self.grammar_compiler.compile_builtin_json_grammar() + else: + ctx = self.grammar_compiler.compile_json_schema(schema=key_string) except RuntimeError as e: logging.warning( f"Skip invalid json_schema: json_schema={key_string}, {e=}" ) return None + elif key_type == 
"ebnf": + try: + ctx = self.grammar_compiler.compile_grammar(key_string) + except RuntimeError as e: + logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}") + return None elif key_type == "regex": - logger.warning( - "regex hasn't been supported by xgrammar yet. This is skipped." - ) - return None + try: + ctx = self.grammar_compiler.compile_grammar( + Grammar.from_regex(key_string) + ) + except RuntimeError as e: + logging.warning(f"Skip invalid regex: regex={key_string}, {e=}") + return None else: raise ValueError(f"Invalid key_type: {key_type}") diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 60dba87cb08..3a775aa1e95 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -452,7 +452,6 @@ def generate_chat_conv( # Add a blank message for the assistant. conv.append_message(conv.roles[1], None) - return conv @@ -555,3 +554,17 @@ def generate_chat_conv( image_token="<|vision_start|><|image_pad|><|vision_end|>", ) ) + +# Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage +register_conv_template( + Conversation( + name="minicpmv", + system_message="You are a helpful assistant", + system_template="<|im_start|>system\n{system_message}.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep="<|im_end|>\n", + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + stop_str=("<|im_end|>", "<|endoftext|>"), + image_token="(./)", + ) +) diff --git a/python/sglang/srt/distributed/__init__.py b/python/sglang/srt/distributed/__init__.py index db325cfabf5..12f802055c5 100644 --- a/python/sglang/srt/distributed/__init__.py +++ b/python/sglang/srt/distributed/__init__.py @@ -1,3 +1,3 @@ -from .communication_op import * -from .parallel_state import * -from .utils import * +from sglang.srt.distributed.communication_op import * +from sglang.srt.distributed.parallel_state import * +from sglang.srt.distributed.utils import * diff --git a/python/sglang/srt/distributed/communication_op.py b/python/sglang/srt/distributed/communication_op.py index ddf3b8ef568..95600edfb41 100644 --- a/python/sglang/srt/distributed/communication_op.py +++ b/python/sglang/srt/distributed/communication_op.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py + from typing import Any, Dict, Optional, Union import torch diff --git a/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py b/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py index ab4ee33fcfc..c902f314112 100644 --- a/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py + """This file is a pure Python wrapper for the cudart library. It avoids the need to compile a separate shared library, and is convenient for use when we just need to call a few functions. 
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index b6df234407d..faeac0bbae9 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce.py + import ctypes import logging import os @@ -6,7 +7,6 @@ from functools import wraps from typing import Callable, List, Optional, TypeVar, Union -import pynvml import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -20,8 +20,19 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as from sglang.srt.utils import cuda_device_count_stateless, is_cuda +logger = logging.getLogger(__name__) + +if is_cuda(): + try: + import pynvml + except ImportError as e: + logger.warning("Failed to import pynvml with %r", e) + try: - ops.meta_size() + if ops.use_vllm_custom_allreduce: + ops.meta_size() + else: + import sgl_kernel custom_ar = True except Exception: # For AMD GPUs and CPUs @@ -29,7 +40,6 @@ logger = logging.getLogger(__name__) - _P = ParamSpec("_P") _R = TypeVar("_R") @@ -47,7 +57,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: @with_nvml_context -def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: +def is_full_nvlink(physical_device_ids: List[int]) -> bool: """ query if the set of gpus are fully connected by nvlink (1 hop) """ @@ -175,9 +185,12 @@ def __init__( # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported # this checks hardware and driver support for NVLink - assert is_cuda() + if is_cuda(): + assert is_cuda() - full_nvlink = is_full_nvlink(physical_device_ids) + full_nvlink = is_full_nvlink(physical_device_ids) + else: + full_nvlink = False if world_size > 2 and not full_nvlink: logger.warning( "Custom allreduce is disabled because it's not supported on" @@ -196,32 +209,64 @@ def __init__( ) return - self.disabled = False - # Buffers memory are owned by this Python class and passed to C++. - # Meta data composes of two parts: meta data for synchronization and a - # temporary buffer for storing intermediate allreduce results. - self.meta_ptrs = self.create_shared_buffer( - ops.meta_size() + max_size, group=group - ) - # This is a pre-registered IPC buffer. In eager mode, input tensors - # are first copied into this buffer before allreduce is performed - self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) - # This is a buffer for storing the tuples of pointers pointing to - # IPC buffers from all ranks. Each registered tuple has size of - # 8*world_size bytes where world_size is at most 8. Allocating 8MB - # is enough for 131072 such tuples. The largest model I've seen only - # needs less than 10000 of registered tuples. 
- self.rank_data = torch.empty( - 8 * 1024 * 1024, dtype=torch.uint8, device=self.device - ) self.max_size = max_size self.rank = rank self.world_size = world_size self.full_nvlink = full_nvlink - self._ptr = ops.init_custom_ar( - self.meta_ptrs, self.rank_data, rank, self.full_nvlink - ) - ops.register_buffer(self._ptr, self.buffer_ptrs) + + if ops.use_vllm_custom_allreduce: + # Buffers memory are owned by this Python class and passed to C++. + # Meta data composes of two parts: meta data for synchronization and a + # temporary buffer for storing intermediate allreduce results. + self.meta_ptrs = self.create_shared_buffer( + ops.meta_size() + max_size, group=group + ) + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. + self.rank_data = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=self.device + ) + self._ptr = ops.init_custom_ar( + self.meta_ptrs, self.rank_data, rank, self.full_nvlink + ) + ops.register_buffer(self._ptr, self.buffer_ptrs) + else: + # From TensorRT-LLM getMaxRequiredWorkspaceSize + self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024] + + # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE; + self.barrier_max_size = 8 * (36 + 2) * 8 + + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + self.tmp_result_buffer_ptrs = self.create_shared_buffer( + max_size, group=group + ) + self.rank_data_base = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=self.device + ) + self.barrier_in_ptrs = self.create_shared_buffer( + self.barrier_max_size, group=group + ) + self.barrier_out_ptrs = self.create_shared_buffer( + self.barrier_max_size, group=group + ) + + self._ptr = ops.init_custom_ar( + rank, + world_size, + self.rank_data_base, + self.buffer_ptrs, + self.tmp_result_buffer_ptrs, + self.barrier_in_ptrs, + self.barrier_out_ptrs, + ) + self.disabled = False @staticmethod def create_shared_buffer( @@ -300,12 +345,31 @@ def should_custom_ar(self, inp: torch.Tensor): return False # for 4 or more non NVLink-capable GPUs, custom allreduce provides # little performance improvement over NCCL. - if self.world_size == 2 or self.full_nvlink: - return inp_size < self.max_size + if ops.use_vllm_custom_allreduce: + if self.world_size == 2 or self.full_nvlink: + return inp_size < self.max_size + return False + + if self.world_size == 2: + return ( + inp_size < self.max_size + and inp_size < self.max_required_workspace_size[0] + ) + + if self.full_nvlink: + return ( + inp_size < self.max_size + and inp_size < self.max_required_workspace_size[1] + ) + return False def all_reduce( - self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False + self, + inp: torch.Tensor, + *, + out: torch.Tensor = None, + registered: bool = False, ): """Performs an out-of-place all reduce. 
@@ -315,12 +379,15 @@ def all_reduce( """ if out is None: out = torch.empty_like(inp) - if registered: - ops.all_reduce(self._ptr, inp, out, 0, 0) + if ops.use_vllm_custom_allreduce: + if registered: + ops.all_reduce(self._ptr, inp, out, 0, 0) + else: + ops.all_reduce( + self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size + ) else: - ops.all_reduce( - self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size - ) + ops.all_reduce(self._ptr, inp, out) return out def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: @@ -336,17 +403,20 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: # allreduce is out-of-place. return torch.empty_like(input) else: - # Note: outside of cuda graph context, custom allreduce incurs a - # cost of cudaMemcpy, which should be small (<=1% of overall - # latency) compared to the performance gain of using custom kernels return self.all_reduce(input, registered=False) def close(self): if not self.disabled and self._ptr: ops.dispose(self._ptr) + if ops.use_vllm_custom_allreduce: + self.free_shared_buffer(self.meta_ptrs) + self.free_shared_buffer(self.buffer_ptrs) + else: + self.free_shared_buffer(self.buffer_ptrs) + self.free_shared_buffer(self.tmp_result_buffer_ptrs) + self.free_shared_buffer(self.barrier_in_ptrs) + self.free_shared_buffer(self.barrier_out_ptrs) self._ptr = 0 - self.free_shared_buffer(self.meta_ptrs) - self.free_shared_buffer(self.buffer_ptrs) def __del__(self): self.close() diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py index d807dfd5ce5..4073491aa62 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce_utils.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce_utils.py + import ctypes import json import logging @@ -7,7 +8,6 @@ import subprocess import sys import tempfile -from functools import lru_cache from itertools import product from typing import Dict, List, Optional, Sequence diff --git a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py index 72ef3889e01..722e494cf77 100644 --- a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/hpu_communicator.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/hpu_communicator.py + import torch import torch.distributed as dist from torch.distributed import ProcessGroup diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py index baee270da90..9f65939f6d9 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl.py @@ -1,8 +1,10 @@ -# Adapted from 
https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py + import logging from contextlib import contextmanager from typing import Optional, Union +# ===================== import region ===================== import torch import torch.distributed as dist from torch.distributed import ProcessGroup, ReduceOp @@ -143,6 +145,57 @@ def all_reduce( cudaStream_t(stream.cuda_stream), ) + def all_gather( + self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}" + ) + if stream is None: + stream = self.stream + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), + input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def reduce_scatter( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None, + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}" + ) + if stream is None: + stream = self.stream + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), + output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: return @@ -179,6 +232,32 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None): cudaStream_t(stream.cuda_stream), ) + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = self.stream + if src == self.rank: + sendbuff = buffer_type(tensor.data_ptr()) + # NCCL requires the sender also to have a receive buffer + recvbuff = buffer_type(tensor.data_ptr()) + else: + sendbuff = buffer_type() + recvbuff = buffer_type(tensor.data_ptr()) + self.nccl.ncclBroadcast( + sendbuff, + recvbuff, + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + @contextmanager def change_state( self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py index e72284f5117..afb47733476 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py @@ -1,4 +1,4 @@ -# Adapted from 
https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py # This file is a pure Python wrapper for the NCCL library. # The main purpose is to use NCCL combined with CUDA graph. @@ -57,7 +57,7 @@ def find_nccl_library() -> str: so_file = "librccl.so.1" else: raise ValueError("NCCL only supports CUDA and ROCm backends.") - logger.info("Found nccl from library %s", so_file) + logger.debug("Found nccl from library %s", so_file) return so_file @@ -187,6 +187,43 @@ class NCCLLibrary: cudaStream_t, ], ), + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclAllGather", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclReduceScatter", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclSend( # const void* sendbuff, size_t count, ncclDataType_t datatype, # int dest, ncclComm_t comm, cudaStream_t stream); @@ -217,6 +254,23 @@ class NCCLLibrary: cudaStream_t, ], ), + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function( + "ncclBroadcast", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # be cautious! this is a collective call, it will block until all # processes in the communicator have called this function. 
# because Python object destruction can happen in random order, @@ -321,6 +375,46 @@ def ncclAllReduce( ) ) + def ncclReduceScatter( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclReduceScatter"]( + sendbuff, recvbuff, count, datatype, op, comm, stream + ) + ) + + def ncclAllGather( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclAllGather"]( + sendbuff, recvbuff, count, datatype, comm, stream + ) + ) + def ncclSend( self, sendbuff: buffer_type, @@ -347,6 +441,22 @@ def ncclRecv( self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream) ) + def ncclBroadcast( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclBroadcast"]( + sendbuff, recvbuff, count, datatype, root, comm, stream + ) + ) + def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) diff --git a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py index 1afe6fca526..7a3b22e27a8 100644 --- a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py +++ b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py @@ -1,11 +1,9 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/shm_broadcast.py -import ipaddress +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/shm_broadcast.py + import logging import os import pickle -import socket import time -import warnings from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory @@ -18,6 +16,8 @@ from zmq import IPV6 # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore +from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address + # SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60 SGLANG_RINGBUFFER_WARNING_INTERVAL = int( os.environ.get("SGLANG_RINGBUFFER_WARNING_INTERVAL", "60") @@ -26,73 +26,6 @@ logger = logging.getLogger(__name__) -def get_ip() -> str: - # SGLANG_HOST_IP env can be ignore - host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "") - if host_ip: - return host_ip - - # IP is not set, try to get it from the network interface - - # try ipv4 - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable - return s.getsockname()[0] - except Exception: - pass - - # try ipv6 - try: - s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) - # Google's public DNS server, see - # https://developers.google.com/speed/public-dns/docs/using#addresses - 
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable - return s.getsockname()[0] - except Exception: - pass - - warnings.warn( - "Failed to get the IP address, using 0.0.0.0 by default." - "The value can be set by the environment variable" - " SGLANG_HOST_IP or HOST_IP.", - stacklevel=2, - ) - return "0.0.0.0" - - -def get_open_port() -> int: - - port = os.getenv("SGLANG_PORT") - if port is not None: - while True: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", port)) - return port - except OSError: - port += 1 # Increment port number if already in use - logger.info("Port %d is already in use, trying port %d", port - 1, port) - # try ipv4 - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - except OSError: - # try ipv6 - with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - - -def is_valid_ipv6_address(address: str) -> bool: - try: - ipaddress.IPv6Address(address) - return True - except ValueError: - return False - - class ShmRingBuffer: def __init__( @@ -313,7 +246,7 @@ def __init__( remote_subscribe_port=remote_subscribe_port, ) - logger.info("vLLM message queue communication handle: %s", self.handle) + logger.debug("Message queue communication handle: %s", self.handle) def export_handle(self) -> Handle: return self.handle diff --git a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py index ff0981b80bc..532279f70c3 100644 --- a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/xpu_communicator.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/xpu_communicator.py + import torch import torch.distributed as dist from torch.distributed import ProcessGroup diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 26d04b04ce9..c6d1a830781 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -1,4 +1,4 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/parallel_state.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/parallel_state.py # Copyright 2023 The vLLM team. # Adapted from diff --git a/python/sglang/srt/distributed/utils.py b/python/sglang/srt/distributed/utils.py index a225fbb9182..e117aa30d07 100644 --- a/python/sglang/srt/distributed/utils.py +++ b/python/sglang/srt/distributed/utils.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/utils.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/utils.py + # Copyright 2023 The vLLM team. 
# Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py new file mode 100644 index 00000000000..098a3d1e325 --- /dev/null +++ b/python/sglang/srt/entrypoints/engine.py @@ -0,0 +1,452 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +The entry point of inference server. (SRT = SGLang Runtime) + +This file implements python APIs for the inference engine. +""" + +import asyncio +import atexit +import dataclasses +import logging +import multiprocessing as mp +import os +import signal +import threading +from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union + +# Fix a bug of Python threading +setattr(threading, "_register_atexit", lambda *args, **kwargs: None) + +import torch +import uvloop + +from sglang.srt.managers.data_parallel_controller import ( + run_data_parallel_controller_process, +) +from sglang.srt.managers.detokenizer_manager import run_detokenizer_process +from sglang.srt.managers.io_struct import ( + EmbeddingReqInput, + GenerateReqInput, + GetWeightsByNameReqInput, + InitWeightsUpdateGroupReqInput, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, + UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromTensorReqInput, +) +from sglang.srt.managers.scheduler import run_scheduler_process +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.utils import ( + MultiprocessingSerializer, + assert_pkg_version, + configure_logger, + kill_process_tree, + launch_dummy_health_check_server, + maybe_set_triton_cache_manager, + prepare_model_and_tokenizer, + set_prometheus_multiproc_dir, + set_ulimit, +) +from sglang.version import __version__ + +logger = logging.getLogger(__name__) +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + +class Engine: + """ + The entry point to the inference engine. + + - The engine consists of three components: + 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. + 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. + 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. + + Note: + 1. The HTTP server, Engine, and TokenizerManager both run in the main process. + 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. + """ + + def __init__(self, **kwargs): + """ + The arguments of this function is the same as `sglang/srt/server_args.py::ServerArgs`. 
+ Please refer to `ServerArgs` for the documentation. + """ + if "server_args" in kwargs: + # Directly load server_args + server_args = kwargs["server_args"] + else: + # Construct server_args from kwargs + if "log_level" not in kwargs: + # Do not print logs by default + kwargs["log_level"] = "error" + server_args = ServerArgs(**kwargs) + + # Shutdown the subprocesses automatically when the program exists + atexit.register(self.shutdown) + + # Launch subprocesses + tokenizer_manager, scheduler_info = _launch_subprocesses( + server_args=server_args + ) + self.tokenizer_manager = tokenizer_manager + self.scheduler_info = scheduler_info + + def generate( + self, + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, + stream: bool = False, + ) -> Union[Dict, Iterator[Dict]]: + """ + The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. + Please refer to `GenerateReqInput` for the documentation. + """ + obj = GenerateReqInput( + text=prompt, + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + lora_path=lora_path, + custom_logit_processor=custom_logit_processor, + stream=stream, + ) + loop = asyncio.get_event_loop() + generator = self.tokenizer_manager.generate_request(obj, None) + + if stream: + + def generator_wrapper(): + while True: + try: + chunk = loop.run_until_complete(generator.__anext__()) + yield chunk + except StopAsyncIteration: + break + + return generator_wrapper() + else: + ret = loop.run_until_complete(generator.__anext__()) + return ret + + async def async_generate( + self, + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, + stream: bool = False, + ) -> Union[Dict, AsyncIterator[Dict]]: + """ + The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. + Please refer to `GenerateReqInput` for the documentation. 
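        Example (an illustrative sketch; the model path is a placeholder and
        `Engine` is assumed to be re-exported as `sglang.Engine`, otherwise
        import it from `sglang.srt.entrypoints.engine`):

            import asyncio
            import sglang as sgl

            async def main():
                llm = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")
                out = await llm.async_generate(
                    prompt="The capital of France is",
                    sampling_params={"temperature": 0, "max_new_tokens": 16},
                )
                print(out)
                llm.shutdown()

            asyncio.run(main())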
+ """ + obj = GenerateReqInput( + text=prompt, + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + lora_path=lora_path, + stream=stream, + custom_logit_processor=custom_logit_processor, + ) + generator = self.tokenizer_manager.generate_request(obj, None) + + if stream is True: + return generator + else: + return await generator.__anext__() + + def encode( + self, + prompt: Union[str, List[str], List[Dict], List[List[Dict]]], + ) -> Dict: + """ + The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. + Please refer to `EmbeddingReqInput` for the documentation. + """ + + obj = EmbeddingReqInput(text=prompt) + loop = asyncio.get_event_loop() + generator = self.tokenizer_manager.generate_request(obj, None) + ret = loop.run_until_complete(generator.__anext__()) + return ret + + def shutdown(self): + """Shutdown the engine""" + kill_process_tree(os.getpid(), include_parent=False) + + def start_profile(self): + self.tokenizer_manager.start_profile() + + def stop_profile(self): + self.tokenizer_manager.stop_profile() + + def get_server_info(self): + return { + **dataclasses.asdict(self.tokenizer_manager.server_args), # server args + **self.scheduler_info, + "version": __version__, + } + + def init_weights_update_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str = "nccl", + ): + """Initialize parameter update group.""" + obj = InitWeightsUpdateGroupReqInput( + master_address=master_address, + master_port=master_port, + rank_offset=rank_offset, + world_size=world_size, + group_name=group_name, + backend=backend, + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.init_weights_update_group(obj, None) + ) + + def update_weights_from_distributed(self, name: str, dtype, shape): + """Update weights from distributed source.""" + obj = UpdateWeightsFromDistributedReqInput( + name=name, + dtype=dtype, + shape=shape, + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.update_weights_from_distributed(obj, None) + ) + + def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]): + """Update weights from distributed source.""" + obj = UpdateWeightsFromTensorReqInput( + serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors) + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.update_weights_from_tensor(obj, None) + ) + + def get_weights_by_name(self, name: str, truncate_size: int = 100): + """Get weights by parameter name.""" + obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.get_weights_by_name(obj, None) + ) + + def release_memory_occupation(self): + """Release GPU occupation temporarily.""" + obj = ReleaseMemoryOccupationReqInput() + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.release_memory_occupation(obj, None) + ) + + def resume_memory_occupation(self): + """Resume GPU occupation.""" + obj = ResumeMemoryOccupationReqInput() + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.resume_memory_occupation(obj, None) + ) + + +def _set_envs_and_config(server_args: ServerArgs): + # Set global 
environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" + + # Set prometheus env vars + if server_args.enable_metrics: + set_prometheus_multiproc_dir() + + # Set ulimit + set_ulimit() + + # Fix triton bugs + if server_args.tp_size * server_args.dp_size > 1: + # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. + maybe_set_triton_cache_manager() + + # Check flashinfer version + if server_args.attention_backend == "flashinfer": + assert_pkg_version( + "flashinfer", + "0.1.6", + "Please uninstall the old version and " + "reinstall the latest version by following the instructions " + "at https://docs.flashinfer.ai/installation.html.", + ) + + # Register the signal handler. + # The child processes will send SIGQUIT to this process when any error happens + # This process then clean up the whole process tree + def sigquit_handler(signum, frame): + logger.error( + "Received sigquit from a child proces. It usually means the child failed." + ) + kill_process_tree(os.getpid()) + + signal.signal(signal.SIGQUIT, sigquit_handler) + + # Set mp start method + mp.set_start_method("spawn", force=True) + + +def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]: + """ + Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. + """ + # Configure global environment + configure_logger(server_args) + server_args.check_server_args() + _set_envs_and_config(server_args) + + # Allocate ports for inter-process communications + port_args = PortArgs.init_new(server_args) + logger.info(f"{server_args=}") + + # If using model from www.modelscope.cn, first download the model. + server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( + server_args.model_path, server_args.tokenizer_path + ) + + scheduler_procs = [] + if server_args.dp_size == 1: + # Launch tensor parallel scheduler processes + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.enable_memory_saver + ) + + scheduler_pipe_readers = [] + tp_size_per_node = server_args.tp_size // server_args.nnodes + tp_rank_range = range( + tp_size_per_node * server_args.node_rank, + tp_size_per_node * (server_args.node_rank + 1), + ) + for tp_rank in tp_rank_range: + reader, writer = mp.Pipe(duplex=False) + gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node + proc = mp.Process( + target=run_scheduler_process, + args=(server_args, port_args, gpu_id, tp_rank, None, writer), + ) + with memory_saver_adapter.configure_subprocess(): + proc.start() + scheduler_procs.append(proc) + scheduler_pipe_readers.append(reader) + else: + # Launch the data parallel controller + reader, writer = mp.Pipe(duplex=False) + scheduler_pipe_readers = [reader] + proc = mp.Process( + target=run_data_parallel_controller_process, + args=(server_args, port_args, writer), + ) + proc.start() + scheduler_procs.append(proc) + + if server_args.node_rank >= 1: + # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer, + # so they can just wait here. + + for reader in scheduler_pipe_readers: + data = reader.recv() + assert data["status"] == "ready" + + if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0": + # When using `Engine` as a Python API, we don't want to block here. 
+ return None, None + + launch_dummy_health_check_server(server_args.host, server_args.port) + + for proc in scheduler_procs: + proc.join() + logger.error( + f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" + ) + return None, None + + # Launch detokenizer process + detoken_proc = mp.Process( + target=run_detokenizer_process, + args=( + server_args, + port_args, + ), + ) + detoken_proc.start() + + # Launch tokenizer process + tokenizer_manager = TokenizerManager(server_args, port_args) + if server_args.chat_template: + load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) + + # Wait for the model to finish loading + scheduler_infos = [] + for i in range(len(scheduler_pipe_readers)): + try: + data = scheduler_pipe_readers[i].recv() + except EOFError: + logger.error( + f"Rank {i} scheduler is dead. Please check if there are relevant logs." + ) + scheduler_procs[i].join() + logger.error(f"Exit code: {scheduler_procs[i].exitcode}") + raise + + if data["status"] != "ready": + raise RuntimeError( + "Initialization failed. Please see the error messages above." + ) + scheduler_infos.append(data) + + # Assume all schedulers have the same scheduler_info + scheduler_info = scheduler_infos[0] + tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] + return tokenizer_manager, scheduler_info diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py new file mode 100644 index 00000000000..1759cd2bb60 --- /dev/null +++ b/python/sglang/srt/entrypoints/http_server.py @@ -0,0 +1,603 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +The entry point of inference server. (SRT = SGLang Runtime) + +This file implements HTTP APIs for the inferenc engine via fastapi. 
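Example (an illustrative request against an already running server; the
host and port are placeholders for wherever the server was launched):

    import requests

    resp = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {"temperature": 0, "max_new_tokens": 16},
        },
    )
    print(resp.json())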
+""" + +import asyncio +import dataclasses +import logging +import multiprocessing as multiprocessing +import os +import threading +import time +from http import HTTPStatus +from typing import AsyncIterator, Dict, Optional + +# Fix a bug of Python threading +setattr(threading, "_register_atexit", lambda *args, **kwargs: None) + +import orjson +import requests +import uvicorn +import uvloop +from fastapi import FastAPI, File, Form, Request, UploadFile +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import ORJSONResponse, Response, StreamingResponse + +from sglang.srt.entrypoints.engine import _launch_subprocesses +from sglang.srt.function_call_parser import FunctionCallParser +from sglang.srt.managers.io_struct import ( + CloseSessionReqInput, + ConfigureLoggingReq, + EmbeddingReqInput, + FunctionCallReqInput, + GenerateReqInput, + GetWeightsByNameReqInput, + InitWeightsUpdateGroupReqInput, + OpenSessionReqInput, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, + UpdateWeightFromDiskReqInput, + UpdateWeightsFromDistributedReqInput, +) +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.metrics.func_timer import enable_func_timer +from sglang.srt.openai_api.adapter import ( + v1_batches, + v1_cancel_batch, + v1_chat_completions, + v1_completions, + v1_delete_file, + v1_embeddings, + v1_files_create, + v1_retrieve_batch, + v1_retrieve_file, + v1_retrieve_file_content, +) +from sglang.srt.openai_api.protocol import ModelCard, ModelList +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import ( + add_api_key_middleware, + add_prometheus_middleware, + delete_directory, + kill_process_tree, + set_uvicorn_logging_configs, +) +from sglang.utils import get_exception_traceback +from sglang.version import __version__ + +logger = logging.getLogger(__name__) +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +# Fast API +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# Store global states +@dataclasses.dataclass +class _GlobalState: + tokenizer_manager: TokenizerManager + scheduler_info: Dict + + +_global_state: Optional[_GlobalState] = None + + +def set_global_state(global_state: _GlobalState): + global _global_state + _global_state = global_state + + +##### Native API endpoints ##### + + +@app.get("/health") +async def health() -> Response: + """Check the health of the http server.""" + return Response(status_code=200) + + +@app.get("/health_generate") +async def health_generate(request: Request) -> Response: + """Check the health of the inference server by generating one token.""" + + sampling_params = {"max_new_tokens": 1, "temperature": 0.7} + + if _global_state.tokenizer_manager.is_generation: + gri = GenerateReqInput( + input_ids=[0], sampling_params=sampling_params, log_metrics=False + ) + else: + gri = EmbeddingReqInput( + input_ids=[0], sampling_params=sampling_params, log_metrics=False + ) + + try: + async for _ in _global_state.tokenizer_manager.generate_request(gri, request): + break + return Response(status_code=200) + except Exception as e: + logger.exception(e) + return Response(status_code=503) + + +@app.get("/get_model_info") +async def get_model_info(): + """Get the model information.""" + result = { + "model_path": _global_state.tokenizer_manager.model_path, + "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path, + "is_generation": 
_global_state.tokenizer_manager.is_generation, + } + return result + + +@app.get("/get_server_info") +async def get_server_info(): + return { + **dataclasses.asdict(_global_state.tokenizer_manager.server_args), + **_global_state.scheduler_info, + "version": __version__, + } + + +# fastapi implicitly converts json in the request to obj (dataclass) +@app.api_route("/generate", methods=["POST", "PUT"]) +async def generate_request(obj: GenerateReqInput, request: Request): + """Handle a generate request.""" + if obj.stream: + + async def stream_results() -> AsyncIterator[bytes]: + try: + async for out in _global_state.tokenizer_manager.generate_request( + obj, request + ): + yield b"data: " + orjson.dumps( + out, option=orjson.OPT_NON_STR_KEYS + ) + b"\n\n" + except ValueError as e: + out = {"error": {"message": str(e)}} + yield b"data: " + orjson.dumps( + out, option=orjson.OPT_NON_STR_KEYS + ) + b"\n\n" + yield b"data: [DONE]\n\n" + + return StreamingResponse( + stream_results(), + media_type="text/event-stream", + background=_global_state.tokenizer_manager.create_abort_task(obj), + ) + else: + try: + ret = await _global_state.tokenizer_manager.generate_request( + obj, request + ).__anext__() + return ret + except ValueError as e: + logger.error(f"Error: {e}") + return _create_error_response(e) + + +@app.api_route("/encode", methods=["POST", "PUT"]) +async def encode_request(obj: EmbeddingReqInput, request: Request): + """Handle an embedding request.""" + try: + ret = await _global_state.tokenizer_manager.generate_request( + obj, request + ).__anext__() + return ret + except ValueError as e: + return _create_error_response(e) + + +@app.api_route("/classify", methods=["POST", "PUT"]) +async def classify_request(obj: EmbeddingReqInput, request: Request): + """Handle a reward model request. Now the arguments and return values are the same as embedding models.""" + try: + ret = await _global_state.tokenizer_manager.generate_request( + obj, request + ).__anext__() + return ret + except ValueError as e: + return _create_error_response(e) + + +@app.post("/flush_cache") +async def flush_cache(): + """Flush the radix cache.""" + _global_state.tokenizer_manager.flush_cache() + return Response( + content="Cache flushed.\nPlease check backend logs for more details. " + "(When there are running or waiting requests, the operation will not be performed.)\n", + status_code=200, + ) + + +@app.api_route("/start_profile", methods=["GET", "POST"]) +async def start_profile_async(): + """Start profiling.""" + _global_state.tokenizer_manager.start_profile() + return Response( + content="Start profiling.\n", + status_code=200, + ) + + +@app.api_route("/stop_profile", methods=["GET", "POST"]) +async def stop_profile_async(): + """Stop profiling.""" + _global_state.tokenizer_manager.stop_profile() + return Response( + content="Stop profiling. 
This will take some time.\n", + status_code=200, + ) + + +@app.post("/update_weights_from_disk") +async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): + """Update the weights from disk in-place without re-launching the server.""" + success, message = await _global_state.tokenizer_manager.update_weights_from_disk( + obj, request + ) + content = {"success": success, "message": message} + if success: + return ORJSONResponse( + content, + status_code=HTTPStatus.OK, + ) + else: + return ORJSONResponse( + content, + status_code=HTTPStatus.BAD_REQUEST, + ) + + +@app.post("/init_weights_update_group") +async def init_weights_update_group( + obj: InitWeightsUpdateGroupReqInput, request: Request +): + """Initialize the parameter update group.""" + success, message = await _global_state.tokenizer_manager.init_weights_update_group( + obj, request + ) + content = {"success": success, "message": message} + if success: + return ORJSONResponse(content, status_code=200) + else: + return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + + +@app.post("/update_weights_from_distributed") +async def update_weights_from_distributed( + obj: UpdateWeightsFromDistributedReqInput, request: Request +): + """Update model parameter from distributed online.""" + success, message = ( + await _global_state.tokenizer_manager.update_weights_from_distributed( + obj, request + ) + ) + content = {"success": success, "message": message} + if success: + return ORJSONResponse(content, status_code=200) + else: + return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + + +@app.api_route("/get_weights_by_name", methods=["GET", "POST"]) +async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): + """Get model parameter by name.""" + try: + ret = await _global_state.tokenizer_manager.get_weights_by_name(obj, request) + if ret is None: + return _create_error_response("Get parameter by name failed") + else: + return ORJSONResponse(ret, status_code=200) + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/release_memory_occupation", methods=["GET", "POST"]) +async def release_memory_occupation( + obj: ReleaseMemoryOccupationReqInput, request: Request +): + """Release GPU occupation temporarily""" + try: + await _global_state.tokenizer_manager.release_memory_occupation(obj, request) + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/resume_memory_occupation", methods=["GET", "POST"]) +async def resume_memory_occupation( + obj: ResumeMemoryOccupationReqInput, request: Request +): + """Resume GPU occupation""" + try: + await _global_state.tokenizer_manager.resume_memory_occupation(obj, request) + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/open_session", methods=["GET", "POST"]) +async def open_session(obj: OpenSessionReqInput, request: Request): + """Open a session, and return its unique session id.""" + try: + session_id = await _global_state.tokenizer_manager.open_session(obj, request) + if session_id is None: + raise Exception( + "Failed to open the session. Check if a session with the same id is still open." 
+ ) + return session_id + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/close_session", methods=["GET", "POST"]) +async def close_session(obj: CloseSessionReqInput, request: Request): + """Close the session""" + try: + await _global_state.tokenizer_manager.close_session(obj, request) + return Response(status_code=200) + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/configure_logging", methods=["GET", "POST"]) +async def configure_logging(obj: ConfigureLoggingReq, request: Request): + """Close the session""" + _global_state.tokenizer_manager.configure_logging(obj) + return Response(status_code=200) + + +@app.post("/function_call") +async def function_call_request(obj: FunctionCallReqInput, request: Request): + """ + A native API endpoint to parse function calls from a text. + """ + # 1) Initialize the parser based on the request body + parser = FunctionCallParser(tools=obj.tools, tool_call_parser=obj.tool_call_parser) + + # 2) Call the non-stream parsing method (non-stream) + normal_text, calls = parser.parse_non_stream(obj.text) + + # 3) Organize the response content + response_data = { + "normal_text": normal_text, + "calls": [ + call.model_dump() for call in calls + ], # Convert pydantic objects to dictionaries + } + + return ORJSONResponse(content=response_data, status_code=200) + + +##### OpenAI-compatible API endpoints ##### + + +@app.post("/v1/completions") +async def openai_v1_completions(raw_request: Request): + return await v1_completions(_global_state.tokenizer_manager, raw_request) + + +@app.post("/v1/chat/completions") +async def openai_v1_chat_completions(raw_request: Request): + return await v1_chat_completions(_global_state.tokenizer_manager, raw_request) + + +@app.post("/v1/embeddings", response_class=ORJSONResponse) +async def openai_v1_embeddings(raw_request: Request): + response = await v1_embeddings(_global_state.tokenizer_manager, raw_request) + return response + + +@app.get("/v1/models", response_class=ORJSONResponse) +def available_models(): + """Show available models.""" + served_model_names = [_global_state.tokenizer_manager.served_model_name] + model_cards = [] + for served_model_name in served_model_names: + model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) + return ModelList(data=model_cards) + + +@app.post("/v1/files") +async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): + return await v1_files_create( + file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth + ) + + +@app.delete("/v1/files/{file_id}") +async def delete_file(file_id: str): + # https://platform.openai.com/docs/api-reference/files/delete + return await v1_delete_file(file_id) + + +@app.post("/v1/batches") +async def openai_v1_batches(raw_request: Request): + return await v1_batches(_global_state.tokenizer_manager, raw_request) + + +@app.post("/v1/batches/{batch_id}/cancel") +async def cancel_batches(batch_id: str): + # https://platform.openai.com/docs/api-reference/batch/cancel + return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id) + + +@app.get("/v1/batches/{batch_id}") +async def retrieve_batch(batch_id: str): + return await v1_retrieve_batch(batch_id) + + +@app.get("/v1/files/{file_id}") +async def retrieve_file(file_id: str): + # https://platform.openai.com/docs/api-reference/files/retrieve + return await v1_retrieve_file(file_id) + + +@app.get("/v1/files/{file_id}/content") +async def retrieve_file_content(file_id: str): + # 
https://platform.openai.com/docs/api-reference/files/retrieve-contents + return await v1_retrieve_file_content(file_id) + + +def _create_error_response(e): + return ORJSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + +def launch_server( + server_args: ServerArgs, + pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None, +): + """ + Launch SRT (SGLang Runtime) Server. + + The SRT server consists of an HTTP server and an SRT engine. + + - HTTP server: A FastAPI server that routes requests to the engine. + - The engine consists of three components: + 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. + 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. + 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. + + Note: + 1. The HTTP server, Engine, and TokenizerManager both run in the main process. + 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. + """ + tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args) + set_global_state( + _GlobalState( + tokenizer_manager=tokenizer_manager, + scheduler_info=scheduler_info, + ) + ) + + # Add api key authorization + if server_args.api_key: + add_api_key_middleware(app, server_args.api_key) + + # Add prometheus middleware + if server_args.enable_metrics: + add_prometheus_middleware(app) + enable_func_timer() + + # Send a warmup request + t = threading.Thread( + target=_wait_and_warmup, + args=( + server_args, + pipe_finish_writer, + _global_state.tokenizer_manager.image_token_id, + ), + ) + t.start() + + try: + # Update logging configs + set_uvicorn_logging_configs() + + # Listen for HTTP requests + uvicorn.run( + app, + host=server_args.host, + port=server_args.port, + log_level=server_args.log_level_http or server_args.log_level, + timeout_keep_alive=5, + loop="uvloop", + ) + finally: + t.join() + + +def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text): + headers = {} + url = server_args.url() + if server_args.api_key: + headers["Authorization"] = f"Bearer {server_args.api_key}" + + # Wait until the server is launched + success = False + for _ in range(120): + time.sleep(1) + try: + res = requests.get(url + "/get_model_info", timeout=5, headers=headers) + assert res.status_code == 200, f"{res=}, {res.text=}" + success = True + break + except (AssertionError, requests.exceptions.RequestException): + last_traceback = get_exception_traceback() + pass + + if not success: + if pipe_finish_writer is not None: + pipe_finish_writer.send(last_traceback) + logger.error(f"Initialization failed. 
warmup error: {last_traceback}") + kill_process_tree(os.getpid()) + return + + model_info = res.json() + + # Send a warmup request + request_name = "/generate" if model_info["is_generation"] else "/encode" + max_new_tokens = 8 if model_info["is_generation"] else 1 + json_data = { + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + } + if server_args.skip_tokenizer_init: + json_data["input_ids"] = [10, 11, 12] + else: + json_data["text"] = "The capital city of France is" + + try: + for _ in range(server_args.dp_size): + res = requests.post( + url + request_name, + json=json_data, + headers=headers, + timeout=600, + ) + assert res.status_code == 200, f"{res}" + except Exception: + last_traceback = get_exception_traceback() + if pipe_finish_writer is not None: + pipe_finish_writer.send(last_traceback) + logger.error(f"Initialization failed. warmup error: {last_traceback}") + kill_process_tree(os.getpid()) + return + + # Debug print + # logger.info(f"{res.json()=}") + + logger.info("The server is fired up and ready to roll!") + if pipe_finish_writer is not None: + pipe_finish_writer.send("ready") + + if server_args.delete_ckpt_after_loading: + delete_directory(server_args.model_path) diff --git a/python/sglang/srt/function_call_parser.py b/python/sglang/srt/function_call_parser.py new file mode 100644 index 00000000000..3def4e1eb27 --- /dev/null +++ b/python/sglang/srt/function_call_parser.py @@ -0,0 +1,494 @@ +import json +import re +from abc import ABC, abstractmethod +from json import JSONDecodeError, JSONDecoder +from typing import Any, Dict, List, Optional, Tuple + +import partial_json_parser +from partial_json_parser.core.options import Allow +from pydantic import BaseModel, Field + +TOOLS_TAG_LIST = [ + "<|plugin|>", + "", + "<|python_tag|>", + "[TOOL_CALLS]", +] + + +class Function(BaseModel): + """Function Tool Template.""" + + description: Optional[str] = Field(default=None, examples=[None]) + name: Optional[str] = None + parameters: Optional[object] = None + + +class ToolCallItem(BaseModel): + """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts.""" + + tool_index: int + name: Optional[str] = None + parameters: str # JSON string + + +def _find_common_prefix(s1: str, s2: str) -> str: + prefix = "" + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]: + try: + return (partial_json_parser.loads(input_str, flags), len(input_str)) + except JSONDecodeError as e: + if "Extra data" in e.msg: + dec = JSONDecoder() + return dec.raw_decode(input_str) + raise + + +def _is_complete_json(input_str: str) -> bool: + try: + json.loads(input_str) + return True + except JSONDecodeError: + return False + + +class StreamingParseResult: + """Result of streaming incremental parsing.""" + + def __init__( + self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None + ): + self.normal_text = normal_text + self.calls = calls or [] + + +class BaseFormatDetector: + """Base class providing two sets of interfaces: one-time and streaming incremental.""" + + def __init__(self): + # initialize properties used for state when parsing tool calls in + self._buffer = "" + # streaming mode + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: List[str] = ( + [] + ) # map what 
has been streamed for each tool so far to a list + self.bot_token = "" + self.eot_token = "" + + def parse_base_json(self, action: Dict, tools: List[Function]): + name, parameters = action["name"], json.dumps( + action.get("parameters", action.get("arguments", {})), + ensure_ascii=False, + ) + tool_index = [tool.function.name for tool in tools].index(name) + tool_call_item = ToolCallItem( + tool_index=tool_index, name=name, parameters=parameters + ) + calls = [tool_call_item] + return calls + + def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + """ + Parses the text in one go. Returns success=True if the format matches, otherwise False. + Note that leftover_text here represents "content that this parser will not consume further". + """ + action = json.loads(text) + return self.parse_base_json(action, tools) + + def parse_streaming_increment( + self, new_text: str, tools: List[Function] + ) -> StreamingParseResult: + """ + Streaming incremental parsing, referencing the logic of Llama32Detector. + We partially parse JSON within ..., and handle + incremental argument output. + """ + # Append new text to buffer + self._buffer += new_text + current_text = self._buffer + if not (self.bot_token in current_text or current_text.startswith("{")): + self._buffer = "" + if self.eot_token in new_text: + new_text = new_text.replace(self.eot_token, "") + return StreamingParseResult(normal_text=new_text) + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR + try: + tool_call_arr = [] + is_complete = [] + try: + # depending on the prompt format the Llama model may or may not + # prefix the output with the <|python_tag|> token + start_idx = ( + len(self.bot_token) + if current_text.startswith(self.bot_token) + else 0 + ) + while start_idx < len(current_text): + (obj, end_idx) = _partial_json_loads( + current_text[start_idx:], flags + ) + is_complete.append( + _is_complete_json(current_text[start_idx : start_idx + end_idx]) + ) + start_idx += end_idx + len("; ") + # depending on the prompt Llama can use + # either arguments or parameters + if "parameters" in obj: + assert ( + "arguments" not in obj + ), "model generated both parameters and arguments" + obj["arguments"] = obj["parameters"] + tool_call_arr.append(obj) + + except partial_json_parser.core.exceptions.MalformedJSON: + # not enough tokens to parse into JSON yet + return StreamingParseResult() + + # select as the current tool call the one we're on the state at + current_tool_call: Dict = ( + tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {} + ) + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return StreamingParseResult() + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif ( + len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1 + ): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. 
+ if self.current_tool_id >= 0: + cur_arguments = current_tool_call.get("arguments") + if cur_arguments: + cur_args_json = json.dumps(cur_arguments) + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + argument_diff = cur_args_json[sent:] + + res = StreamingParseResult( + normal_text=None, + calls=[ + ToolCallItem( + tool_index=self.current_tool_id, + name="", + parameters=argument_diff, + ) + ], + ) + self.streamed_args_for_tool[ + self.current_tool_id + ] += argument_diff + else: + res = StreamingParseResult() + else: + res = StreamingParseResult() + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + print("starting on new tool %d", self.current_tool_id) + return res + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + elif not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + res = StreamingParseResult( + normal_text=None, + calls=[ + ToolCallItem( + tool_index=self.current_tool_id, + name=function_name, + parameters="", + ) + ], + ) + self.current_tool_name_sent = True + else: + res = StreamingParseResult() + + # now we know we're on the same tool call and we're streaming + # arguments + else: + cur_arguments = current_tool_call.get("arguments") + res = StreamingParseResult() + + if cur_arguments: + sent = len(self.streamed_args_for_tool[self.current_tool_id]) + cur_args_json = json.dumps(cur_arguments) + prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments" + ) + + argument_diff = None + if is_complete[self.current_tool_id]: + argument_diff = cur_args_json[sent:] + self._buffer = "" + self.prev_tool_call_arr[self.current_tool_id].clear() + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool[self.current_tool_id] = "" + + elif prev_arguments: + prev_args_json = json.dumps(prev_arguments) + if cur_args_json != prev_args_json: + + prefix = _find_common_prefix(prev_args_json, cur_args_json) + argument_diff = prefix[sent:] + + if argument_diff is not None: + res = StreamingParseResult( + calls=[ + ToolCallItem( + tool_index=self.current_tool_id, + name="", + parameters=argument_diff, + ) + ], + ) + if not is_complete[self.current_tool_id]: + self.streamed_args_for_tool[ + self.current_tool_id + ] += argument_diff + + self.prev_tool_call_arr = tool_call_arr + return res + + except Exception as e: + print(e) + # Skipping chunk as a result of tool streaming extraction error + return StreamingParseResult() + + +class Qwen25Detector(BaseFormatDetector): + """ + Detector for Qwen 2.5 models. + Assumes function call format: + {"name":"xxx", "arguments":{...}} + """ + + def __init__(self): + """ + Initializes the detector with necessary state variables. + """ + super().__init__() + self.bot_token = "" + self.eot_token = "" + + def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + """ + One-time parsing: Detects and parses tool calls in the provided text. + + :param text: The complete text to parse. + :param tools: List of available tools. + :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. 
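        Example (an illustrative sketch; the <tool_call>...</tool_call> wrapper
        tags are the Qwen 2.5 delimiters this detector is meant to match, and
        `tools` stands in for the server-provided tool list):

            detector = Qwen25Detector()
            text = '<tool_call>{"name": "get_weather", "arguments": {"city": "Boston"}}</tool_call>'
            calls = detector.detect_and_parse(text, tools)
            # calls[0].name == "get_weather"
            # calls[0].parameters == '{"city": "Boston"}'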
+ """ + if "" not in text: + return [] + pattern = r"(.*?)" + match_result_list = re.findall(pattern, text, re.DOTALL) + calls = [] + for match_result in match_result_list: + match_result = json.loads(match_result) + calls.extend(self.parse_base_json(match_result, tools)) + return calls + + +class MistralDetector(BaseFormatDetector): + """ + Detector for Mistral models. + Assumes function call format: + <|action_start|><|plugin|>{"name":"xxx", "arguments":{...}}<|action_end|> + """ + + def __init__(self): + """ + Initializes the detector with necessary state variables. + """ + super().__init__() + self.bot_token = "[TOOL_CALLS] [" + self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL) + + def _clean_text(self, text: str) -> str: + """ + clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]' + for example, + text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.' + return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]' + The key pattern is [TOOL_CALLS] [...] + """ + find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL) + if len(find_results) > 0: + return find_results[0] + else: + return "" + + def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + """ + One-time parsing: Detects and parses tool calls in the provided text. + + :param text: The complete text to parse. + :param tools: List of available tools. + :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. + """ + text = self._clean_text(text) + tool_content = text.replace("[TOOL_CALLS]", "").strip() + raw_tool_calls = self.tool_call_regex.findall(tool_content) + calls = [] + if len(raw_tool_calls) > 0: + raw_tool_call = raw_tool_calls[0] + function_call_arr = json.loads(raw_tool_call) + for match_result in function_call_arr: + calls.extend(self.parse_base_json(match_result, tools)) + return calls + + +class Llama32Detector(BaseFormatDetector): + """ + Detector for Llama 3.2 models. + Assumes function call format: + <|python_tag|>{"name":"xxx", "arguments":{...}} + Does not require a closing tag "", + relies on json.loads(...) success to determine if JSON is complete. + """ + + def __init__(self): + """ + Initializes the detector with necessary state variables. + """ + super().__init__() + self.bot_token = "<|python_tag|>" + + def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]: + """ + One-time parsing: Detects and parses tool calls in the provided text. + + :param text: The complete text to parse. + :param tools: List of available tools. + :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls. 
+ """ + + if "<|python_tag|>" not in text: + return [] + _, action = text.split("<|python_tag|>") + action = json.loads(action) + return self.parse_base_json(action, tools) + + +class MultiFormatParser: + def __init__(self, detectors: List[BaseFormatDetector]): + """ + :param detectors: A series of available Detector instances passed in + """ + self.detectors = detectors + + def parse_once(self, text: str, tools: List[Function]): + """ + One-time parsing: Loop through detectors until there are no new matches or text is exhausted + Return: (final_text, all_calls) + - final_text: The remaining text after parsing that was not consumed by any Detector (can be treated as normal text) + - all_calls: All calls parsed by the Detectors + """ + final_calls = [] + final_normal_text = text + for detector in self.detectors: + tool_call_list = detector.detect_and_parse(text, tools) + if len(tool_call_list) > 0: # parsed successfully + final_calls = tool_call_list + break + + # leftover_text is the normal text not consumed by any Detector + return final_normal_text, final_calls + + def parse_streaming_increment(self, new_text: str, tools: List[Function]): + """ + Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment + and merge their produced normal_text/calls to return. + (The logic here can be "priority-based" or "parallel parsing" based on your needs) + """ + final_normal_text = "" + final_calls = [] + + for detector in self.detectors: + sp_result = detector.parse_streaming_increment(new_text, tools) + # Merge normal_text and calls + # If one sp_result contains result call, this should be a successful parse + # If one sp_result only contains normal_text, this can either be a successful + # parse or it is not using the desired parsing tool. + if sp_result.normal_text: + final_normal_text = sp_result.normal_text + if sp_result.calls: + final_calls.extend(sp_result.calls) + final_normal_text = sp_result.normal_text + break + + return final_normal_text, final_calls + + +class FunctionCallParser: + """ + In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment + and returns the resulting normal_text and calls to the upper layer (or SSE). 
+ """ + + ToolCallParserEnum: Dict[str, BaseFormatDetector] = { + "llama3": Llama32Detector, + "qwen25": Qwen25Detector, + "mistral": MistralDetector, + } + + def __init__(self, tools: List[Function], tool_call_parser: str = None): + detectors = [] + if tool_call_parser: + detector_class = self.ToolCallParserEnum.get(tool_call_parser) + if detector_class: + detectors.append(detector_class()) + else: + raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}") + else: + raise ValueError("Tool Call Parser Not Given!") + + self.multi_format_parser = MultiFormatParser(detectors) + self.tools = tools + + def parse_non_stream(self, full_text: str): + """ + Non-streaming call: one-time parsing + """ + full_normal_text, calls = self.multi_format_parser.parse_once( + full_text, self.tools + ) + return full_normal_text, calls + + def parse_stream_chunk(self, chunk_text: str): + """ + Streaming call: incremental parsing + """ + normal_text, calls = self.multi_format_parser.parse_streaming_increment( + chunk_text, self.tools + ) + return normal_text, calls diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 92b01d4524f..ea39d73f2ee 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -30,20 +30,15 @@ ) from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES -try: - from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig - - from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig - - _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { - ChatGLMConfig.model_type: ChatGLMConfig, - DbrxConfig.model_type: DbrxConfig, - ExaoneConfig.model_type: ExaoneConfig, - Qwen2VLConfig.model_type: Qwen2VLConfig, - } -except ImportError: - # We want this file to run without vllm dependency - _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {} +from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2VLConfig + +_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { + ChatGLMConfig.model_type: ChatGLMConfig, + DbrxConfig.model_type: DbrxConfig, + ExaoneConfig.model_type: ExaoneConfig, + Qwen2VLConfig.model_type: Qwen2VLConfig, +} + for name, cls in _CONFIG_REGISTRY.items(): with contextlib.suppress(ValueError): diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index c4c54f0b03c..d69d854ab2e 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -20,18 +20,18 @@ import torch.nn as nn import torch.nn.functional as F -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda_available -if is_flashinfer_available(): - from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul +if is_cuda_available(): + from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul -from vllm.distributed import ( +from vllm.model_executor.custom_op import CustomOp + +from sglang.srt.distributed import ( divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.custom_op import CustomOp - from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.utils import set_weight_attrs @@ -149,8 +149,8 @@ def get_act_fn( return act_fn -if not is_flashinfer_available(): +if not is_cuda_available(): logger.info( - "FlashInfer is not available on Non-NV platforms. 
Fallback to other kernel libraries." + "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." ) from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index a70e9537bfe..74559864302 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -1,11 +1,14 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Optional +from typing import TYPE_CHECKING, Optional import torch -from torch import nn -from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode + from sglang.srt.speculative.spec_info import SpecInfo class AttentionBackend(ABC): @@ -23,9 +26,12 @@ def init_cuda_graph_state(self, max_bs: int): def init_forward_metadata_capture_cuda_graph( self, bs: int, + num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - encoder_lens: Optional[torch.Tensor] = None, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], ): """Init the metadata for a forward pass for capturing a cuda graph.""" raise NotImplementedError() @@ -36,7 +42,9 @@ def init_forward_metadata_replay_cuda_graph( req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - encoder_lens: Optional[torch.Tensor] = None, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], ): """Init the metadata for a forward pass for replying a cuda graph.""" raise NotImplementedError() @@ -58,7 +66,14 @@ def forward( if forward_batch.forward_mode.is_decode(): return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache) else: - return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache) + return self.forward_extend( + q, + k, + v, + layer, + forward_batch, + save_kv_cache, + ) def forward_decode( self, diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index 856aa984c38..a5e54f32d51 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -3,7 +3,6 @@ from typing import TYPE_CHECKING import torch -import torch.nn as nn from sglang.srt.layers.attention import AttentionBackend from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -52,8 +51,6 @@ def __init__(self, model_runner: ModelRunner): self.forward_metadata = None - self.cuda_graph_max_seq_len = model_runner.model_config.context_len - def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" @@ -115,55 +112,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): ds_req_to_token, ) - def init_cuda_graph_state(self, max_bs: int): - # TODO(Andy): Support CUDA graph for double sparse attention - raise ValueError( - "Double sparse attention does not support CUDA graph for now. 
Please --disable-cuda-graph" - ) - self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len - - self.cuda_graph_start_loc = torch.zeros( - (max_bs,), dtype=torch.int32, device="cuda" - ) - self.cuda_graph_attn_logits = torch.empty( - ( - self.num_head, - self.cuda_graph_max_total_num_tokens, - ), - dtype=self.reduce_dtype, - device="cuda", - ) - - def init_forward_metadata_capture_cuda_graph( - self, - bs: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - encoder_lens=None, - ): - # NOTE: encoder_lens expected to be zeros or None - self.forward_metadata = ( - self.cuda_graph_start_loc, - self.cuda_graph_attn_logits, - self.cuda_graph_max_seq_len, - None, - ) - - def init_forward_metadata_replay_cuda_graph( - self, - bs: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - seq_lens_sum: int, - encoder_lens=None, - ): - # NOTE: encoder_lens expected to be zeros or None - self.cuda_graph_start_loc.zero_() - self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) - - def get_cuda_graph_seq_len_fill_value(self): - return 1 - def forward_extend( self, q, diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index f89bc2ccaa2..cc6da781f56 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -8,8 +8,9 @@ """ import os +from dataclasses import dataclass from enum import Enum, auto -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Optional, Union import torch import triton @@ -17,16 +18,14 @@ from sglang.global_config import global_config from sglang.srt.layers.attention import AttentionBackend -from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.utils import ( - get_bool_env_var, - is_flashinfer_available, - should_use_tensor_core, -) +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.utils import is_flashinfer_available if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInfo if is_flashinfer_available(): from flashinfer import ( @@ -42,21 +41,33 @@ class WrapperDispatch(Enum): CROSS_ATTENTION = auto() +@dataclass +class DecodeMetadata: + decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper] + + +@dataclass +class PrefillMetadata: + prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper] + use_ragged: bool + extend_no_prefix: bool + + class FlashInferAttnBackend(AttentionBackend): """Flashinfer attention kernels.""" def __init__(self, model_runner: ModelRunner): super().__init__() + # Parse constants self.decode_use_tensor_cores = should_use_tensor_core( kv_cache_dtype=model_runner.kv_cache_dtype, num_attention_heads=model_runner.model_config.num_attention_heads - // model_runner.tp_size, + // get_attention_tp_size(), num_kv_heads=model_runner.model_config.get_num_kv_heads( - model_runner.tp_size + get_attention_tp_size() ), ) - self.max_context_len = model_runner.model_config.context_len assert not ( @@ -74,6 +85,10 @@ def __init__(self, model_runner: ModelRunner): self.num_wrappers = 1 self.dispatch_reason = None + # Qwen2 models require higher flashinfer workspace size + if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures: + 
global_config.flashinfer_workspace_size = 512 * 1024 * 1024 + # Allocate buffers self.workspace_buffer = torch.empty( global_config.flashinfer_workspace_size, @@ -104,11 +119,15 @@ def __init__(self, model_runner: ModelRunner): # Two wrappers: one for sliding window attention and one for full attention. # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs self.prefill_wrappers_paged = [] + self.prefill_wrappers_verify = [] self.decode_wrappers = [] for _ in range(self.num_wrappers): self.prefill_wrappers_paged.append( BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") ) + self.prefill_wrappers_verify.append( + BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") + ) self.decode_wrappers.append( BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, @@ -124,19 +143,49 @@ def __init__(self, model_runner: ModelRunner): ) # Other metadata - self.forward_metadata = None - self.cuda_graph_metadata = {} + self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None + self.decode_cuda_graph_metadata = {} + self.prefill_cuda_graph_metadata = {} def init_forward_metadata(self, forward_batch: ForwardBatch): - if forward_batch.forward_mode.is_decode(): + if forward_batch.forward_mode.is_decode_or_idle(): self.indices_updater_decode.update( forward_batch.req_pool_indices, forward_batch.seq_lens, forward_batch.seq_lens_sum, - decode_wrappers=None, + decode_wrappers=self.decode_wrappers, + encoder_lens=forward_batch.encoder_lens, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = DecodeMetadata(self.decode_wrappers) + elif forward_batch.forward_mode.is_draft_extend(): + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens=None, + prefill_wrappers=self.prefill_wrappers_paged, + use_ragged=False, encoder_lens=forward_batch.encoder_lens, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = PrefillMetadata( + self.prefill_wrappers_paged, False, False + ) + elif forward_batch.forward_mode.is_target_verify(): + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens=None, + prefill_wrappers=self.prefill_wrappers_verify, + use_ragged=False, + encoder_lens=forward_batch.encoder_lens, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = PrefillMetadata( + self.prefill_wrappers_verify, False, False ) - self.forward_metadata = (self.decode_wrappers,) else: prefix_lens = forward_batch.extend_prefix_lens @@ -153,11 +202,14 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): forward_batch.seq_lens, forward_batch.seq_lens_sum, prefix_lens, + prefill_wrappers=self.prefill_wrappers_paged, use_ragged=use_ragged, encoder_lens=forward_batch.encoder_lens, + spec_info=None, + ) + self.forward_metadata = PrefillMetadata( + self.prefill_wrappers_paged, use_ragged, extend_no_prefix ) - - self.forward_metadata = (use_ragged, extend_no_prefix) def init_cuda_graph_state(self, max_bs: int): cuda_graph_kv_indices = torch.zeros( @@ -169,37 +221,82 @@ def init_cuda_graph_state(self, max_bs: int): cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1) ] + self.cuda_graph_custom_mask = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.uint8, + device="cuda", + ) + self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr] + self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr] + def 
init_forward_metadata_capture_cuda_graph( self, bs: int, + num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - encoder_lens: torch.Tensor = None, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], ): - decode_wrappers = [] - for i in range(self.num_wrappers): - decode_wrappers.append( - BatchDecodeWithPagedKVCacheWrapper( - self.workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=self.decode_use_tensor_cores, - paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1], - paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], - paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs], + if forward_mode.is_decode_or_idle(): + decode_wrappers = [] + for i in range(self.num_wrappers): + decode_wrappers.append( + BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=self.decode_use_tensor_cores, + paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1], + paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], + paged_kv_last_page_len_buffer=self.kv_last_page_len[ + :num_tokens + ], + ) ) + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_decode.update( + req_pool_indices, + seq_lens, + seq_lens_sum, + decode_wrappers=decode_wrappers, + encoder_lens=encoder_lens, + spec_info=spec_info, ) - - seq_lens_sum = seq_lens.sum().item() - self.indices_updater_decode.update( - req_pool_indices, - seq_lens, - seq_lens_sum, - decode_wrappers=decode_wrappers, - encoder_lens=encoder_lens, - ) - self.cuda_graph_metadata[bs] = decode_wrappers - self.forward_metadata = (decode_wrappers,) + self.decode_cuda_graph_metadata[bs] = decode_wrappers + self.forward_metadata = DecodeMetadata(decode_wrappers) + elif forward_mode.is_target_verify(): + prefill_wrappers = [] + for i in range(self.num_wrappers): + prefill_wrappers.append( + BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_cuda_graph=True, + qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1], + paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1], + paged_kv_indices_buf=self.cuda_graph_kv_indices[i], + paged_kv_last_page_len_buf=self.kv_last_page_len[:bs], + custom_mask_buf=self.cuda_graph_custom_mask, + qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1], + ) + ) + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_prefill.update( + req_pool_indices, + seq_lens, + seq_lens_sum, + prefix_lens=None, + prefill_wrappers=prefill_wrappers, + use_ragged=False, + encoder_lens=encoder_lens, + spec_info=spec_info, + ) + self.prefill_cuda_graph_metadata[bs] = prefill_wrappers + self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False) + else: + raise ValueError(f"Invalid mode: {forward_mode=}") def init_forward_metadata_replay_cuda_graph( self, @@ -207,44 +304,63 @@ def init_forward_metadata_replay_cuda_graph( req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - encoder_lens: torch.Tensor = None, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], ): - self.indices_updater_decode.update( - req_pool_indices[:bs], - seq_lens[:bs], - seq_lens_sum, - decode_wrappers=self.cuda_graph_metadata[bs], - encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, - ) + if forward_mode.is_decode_or_idle(): + self.indices_updater_decode.update( + req_pool_indices[:bs], + seq_lens[:bs], + seq_lens_sum, + decode_wrappers=self.decode_cuda_graph_metadata[bs], + encoder_lens=encoder_lens[:bs] if 
encoder_lens is not None else None, + spec_info=spec_info, + ) + elif forward_mode.is_target_verify(): + self.indices_updater_prefill.update( + req_pool_indices[:bs], + seq_lens[:bs], + seq_lens_sum, + prefix_lens=None, + prefill_wrappers=self.prefill_cuda_graph_metadata[bs], + use_ragged=False, + encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, + spec_info=spec_info, + ) + else: + raise ValueError("Invalid forward mode") def get_cuda_graph_seq_len_fill_value(self): return 0 def forward_extend( self, - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, ): - prefill_wrapper_paged = self.prefill_wrappers_paged[ + prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[ self._get_wrapper_idx(layer) ] - - use_ragged, extend_no_prefix = self.forward_metadata cache_loc = ( forward_batch.out_cache_loc if not layer.is_cross_attention else forward_batch.encoder_out_cache_loc ) - if not use_ragged: + logits_soft_cap = layer.logit_cap + + if not self.forward_metadata.use_ragged: if k is not None: assert v is not None if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) o = prefill_wrapper_paged.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), @@ -252,7 +368,9 @@ def forward_extend( causal=not layer.is_cross_attention, sm_scale=layer.scaling, window_left=layer.sliding_window_size, - logits_soft_cap=layer.logit_cap, + logits_soft_cap=logits_soft_cap, + k_scale=layer.k_scale, + v_scale=layer.v_scale, ) else: o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( @@ -261,10 +379,10 @@ def forward_extend( v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim), causal=True, sm_scale=layer.scaling, - logits_soft_cap=layer.logit_cap, + logits_soft_cap=logits_soft_cap, ) - if extend_no_prefix: + if self.forward_metadata.extend_no_prefix: o = o1 else: o2, s2 = prefill_wrapper_paged.forward_return_lse( @@ -278,20 +396,24 @@ def forward_extend( o, _ = merge_state(o1, s1, o2, s2) if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) def forward_decode( self, - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, ): - decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)] + decode_wrapper = self.forward_metadata.decode_wrappers[ + self._get_wrapper_idx(layer) + ] cache_loc = ( forward_batch.out_cache_loc if not layer.is_cross_attention @@ -301,13 +423,17 @@ def forward_decode( if k is not None: assert v is not None if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) o = decode_wrapper.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), sm_scale=layer.scaling, logits_soft_cap=layer.logit_cap, + k_scale=layer.k_scale, + v_scale=layer.v_scale, ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) @@ -326,27 +452,25 @@ def _get_wrapper_idx(self, layer: RadixAttention): class FlashInferIndicesUpdaterDecode: 
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): - # Constants + # Parse Constants self.num_qo_heads = ( - model_runner.model_config.num_attention_heads // model_runner.tp_size + model_runner.model_config.num_attention_heads // get_attention_tp_size() ) self.num_kv_heads = model_runner.model_config.get_num_kv_heads( - model_runner.tp_size + get_attention_tp_size() ) self.head_dim = model_runner.model_config.head_dim self.data_type = model_runner.kv_cache_dtype self.q_data_type = model_runner.dtype self.sliding_window_size = model_runner.sliding_window_size - self.attn_backend = attn_backend # Buffers and wrappers self.kv_indptr = attn_backend.kv_indptr self.kv_last_page_len = attn_backend.kv_last_page_len self.req_to_token = model_runner.req_to_token_pool.req_to_token - self.decode_wrappers = attn_backend.decode_wrappers - # Dispatch + # Dispatch the update function if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: self.update = self.update_sliding_window elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: @@ -360,8 +484,9 @@ def update( req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - decode_wrappers: List, - encoder_lens: torch.Tensor, + decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): # Keep the signature for type checking. It will be assigned during runtime. raise NotImplementedError() @@ -371,8 +496,9 @@ def update_single_wrapper( req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - decode_wrappers: List, - encoder_lens: torch.Tensor, + decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): decode_wrappers = decode_wrappers or self.decode_wrappers self.call_begin_forward( @@ -382,6 +508,7 @@ def update_single_wrapper( seq_lens_sum, self.kv_indptr[0], None, + spec_info, ) def update_sliding_window( @@ -389,11 +516,10 @@ def update_sliding_window( req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - decode_wrappers: List, - encoder_lens: torch.Tensor, + decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): - decode_wrappers = decode_wrappers or self.decode_wrappers - for wrapper_id in range(2): if wrapper_id == 0: # Sliding window attention @@ -416,6 +542,7 @@ def update_sliding_window( paged_kernel_lens_sum_tmp, self.kv_indptr[wrapper_id], kv_start_idx_tmp, + spec_info, ) def update_cross_attention( @@ -423,11 +550,10 @@ def update_cross_attention( req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - decode_wrappers: List, - encoder_lens: torch.Tensor, + decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): - decode_wrappers = decode_wrappers or self.decode_wrappers - for wrapper_id in range(2): if wrapper_id == 0: # Normal attention @@ -446,33 +572,41 @@ def update_cross_attention( seq_lens_sum, self.kv_indptr[wrapper_id], kv_start_idx, + spec_info, ) def call_begin_forward( self, - wrapper, + wrapper: BatchDecodeWithPagedKVCacheWrapper, req_pool_indices: torch.Tensor, paged_kernel_lens: torch.Tensor, paged_kernel_lens_sum: int, kv_indptr: torch.Tensor, kv_start_idx: torch.Tensor, + spec_info: Optional[SpecInfo], ): - bs = len(req_pool_indices) - kv_indptr[1 : bs + 1] = 
torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.empty( - paged_kernel_lens_sum, dtype=torch.int32, device="cuda" - ) - - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - paged_kernel_lens, - kv_indptr, - kv_start_idx, - kv_indices, - self.req_to_token.shape[1], - ) + if spec_info is None: + bs = len(req_pool_indices) + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + paged_kernel_lens_sum, dtype=torch.int32, device="cuda" + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + self.req_to_token.shape[1], + ) + else: + bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode( + req_pool_indices, + paged_kernel_lens, + self.req_to_token, + ) wrapper.end_forward() wrapper.begin_forward( @@ -490,18 +624,17 @@ def call_begin_forward( class FlashInferIndicesUpdaterPrefill: def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): - # Constants + # Parse Constants self.num_qo_heads = ( - model_runner.model_config.num_attention_heads // model_runner.tp_size + model_runner.model_config.num_attention_heads // get_attention_tp_size() ) self.num_kv_heads = model_runner.model_config.get_num_kv_heads( - model_runner.tp_size + get_attention_tp_size() ) self.head_dim = model_runner.model_config.head_dim self.data_type = model_runner.kv_cache_dtype self.q_data_type = model_runner.dtype self.sliding_window_size = model_runner.sliding_window_size - self.attn_backend = attn_backend # Buffers and wrappers @@ -509,10 +642,9 @@ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): self.kv_last_page_len = attn_backend.kv_last_page_len self.qo_indptr = attn_backend.qo_indptr self.req_to_token = model_runner.req_to_token_pool.req_to_token - self.wrapper_ragged = attn_backend.prefill_wrapper_ragged - self.wrappers_paged = attn_backend.prefill_wrappers_paged + self.prefill_wrapper_ragged = attn_backend.prefill_wrapper_ragged - # Dispatch + # Dispatch the update function if self.attn_backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: self.update = self.update_sliding_window elif self.attn_backend.dispatch_reason == WrapperDispatch.CROSS_ATTENTION: @@ -527,8 +659,10 @@ def update( seq_lens: torch.Tensor, seq_lens_sum: int, prefix_lens: torch.Tensor, + prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, - encoder_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): # Keep the signature for type checking. It will be assigned during runtime. 
raise NotImplementedError() @@ -539,8 +673,10 @@ def update_single_wrapper( seq_lens: torch.Tensor, seq_lens_sum: int, prefix_lens: torch.Tensor, + prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, - encoder_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): if use_ragged: paged_kernel_lens = prefix_lens @@ -550,8 +686,8 @@ def update_single_wrapper( paged_kernel_lens_sum = seq_lens_sum self.call_begin_forward( - self.wrapper_ragged, - self.wrappers_paged[0], + self.prefill_wrapper_ragged, + prefill_wrappers[0], req_pool_indices, paged_kernel_lens, paged_kernel_lens_sum, @@ -561,6 +697,7 @@ def update_single_wrapper( self.kv_indptr[0], self.qo_indptr[0], use_ragged, + spec_info, ) def update_sliding_window( @@ -569,8 +706,10 @@ def update_sliding_window( seq_lens: torch.Tensor, seq_lens_sum: int, prefix_lens: torch.Tensor, + prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, - encoder_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): for wrapper_id in range(2): if wrapper_id == 0: @@ -588,8 +727,8 @@ def update_sliding_window( kv_start_idx = seq_lens - paged_kernel_lens self.call_begin_forward( - self.wrapper_ragged, - self.wrappers_paged[wrapper_id], + self.prefill_wrapper_ragged, + prefill_wrappers[wrapper_id], req_pool_indices, paged_kernel_lens, paged_kernel_lens_sum, @@ -599,6 +738,7 @@ def update_sliding_window( self.kv_indptr[wrapper_id], self.qo_indptr[wrapper_id], use_ragged, + spec_info, ) def update_cross_attention( @@ -607,8 +747,10 @@ def update_cross_attention( seq_lens: torch.Tensor, seq_lens_sum: int, prefix_lens: torch.Tensor, + prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, - encoder_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): for wrapper_id in range(2): if wrapper_id == 0: @@ -623,8 +765,8 @@ def update_cross_attention( paged_kernel_lens_sum = paged_kernel_lens.sum().item() self.call_begin_forward( - self.wrapper_ragged, - self.wrappers_paged[wrapper_id], + self.prefill_wrapper_ragged, + prefill_wrappers[wrapper_id], req_pool_indices, paged_kernel_lens, paged_kernel_lens_sum, @@ -634,12 +776,13 @@ def update_cross_attention( self.kv_indptr[wrapper_id], self.qo_indptr[wrapper_id], use_ragged, + spec_info, ) def call_begin_forward( self, - wrapper_ragged, - wrapper_paged, + wrapper_ragged: BatchPrefillWithRaggedKVCacheWrapper, + wrapper_paged: BatchPrefillWithPagedKVCacheWrapper, req_pool_indices: torch.Tensor, paged_kernel_lens: torch.Tensor, paged_kernel_lens_sum: int, @@ -649,25 +792,39 @@ def call_begin_forward( kv_indptr: torch.Tensor, qo_indptr: torch.Tensor, use_ragged: bool, + spec_info: Optional[SpecInfo], ): bs = len(req_pool_indices) - kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.empty( - paged_kernel_lens_sum, dtype=torch.int32, device="cuda" - ) - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - paged_kernel_lens, - kv_indptr, - kv_start_idx, - kv_indices, - self.req_to_token.shape[1], - ) + if spec_info is None: + # Normal extend + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + paged_kernel_lens_sum + 256, + dtype=torch.int32, + device=req_pool_indices.device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + 
req_pool_indices, + paged_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + self.req_to_token.shape[1], + ) - qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) - qo_indptr = qo_indptr[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) + qo_indptr = qo_indptr[: bs + 1] + custom_mask = None + else: + kv_indices, kv_indptr, qo_indptr, custom_mask = ( + spec_info.generate_attn_arg_prefill( + req_pool_indices, + paged_kernel_lens, + self.req_to_token, + ) + ) # extend part if use_ragged: @@ -678,6 +835,7 @@ def call_begin_forward( self.num_qo_heads, self.num_kv_heads, self.head_dim, + q_data_type=self.q_data_type, ) # cached part @@ -691,6 +849,8 @@ def call_begin_forward( self.num_kv_heads, self.head_dim, 1, + q_data_type=self.q_data_type, + custom_mask=custom_mask, ) @@ -729,3 +889,51 @@ def create_flashinfer_kv_indices_triton( mask=mask, ) tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask) + + +def should_use_tensor_core( + kv_cache_dtype: torch.dtype, + num_attention_heads: int, + num_kv_heads: int, +) -> bool: + """ + Determine whether to use tensor cores for attention computation. + + Args: + kv_cache_dtype: Data type of the KV cache + num_attention_heads: Number of attention heads + num_kv_heads: Number of key/value heads + + Returns: + bool: Whether to use tensor cores + """ + # Try to use environment variable first + env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE") + if env_override is not None: + return env_override.lower() == "true" + + # Try to use _grouped_size_compiled_for_decode_kernels if available + # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug + try: + from flashinfer.decode import _grouped_size_compiled_for_decode_kernels + + if not _grouped_size_compiled_for_decode_kernels( + num_attention_heads, + num_kv_heads, + ): + return True + else: + return False + except (ImportError, AttributeError): + pass + + # Calculate GQA group size + gqa_group_size = num_attention_heads // num_kv_heads + + # Determine based on dtype and GQA group size + if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + return True + elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16): + return gqa_group_size > 4 + else: + return False diff --git a/python/sglang/srt/layers/attention/torch_native_backend.py b/python/sglang/srt/layers/attention/torch_native_backend.py index 5e7e0e66e22..f73cd168e00 100644 --- a/python/sglang/srt/layers/attention/torch_native_backend.py +++ b/python/sglang/srt/layers/attention/torch_native_backend.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import torch from torch.nn.functional import scaled_dot_product_attention @@ -23,43 +23,6 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): """Init the metadata for a forward pass.""" pass - def init_cuda_graph_state(self, max_bs: int): - # TODO: Support CUDA graph - raise ValueError( - "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph" - ) - - def init_forward_metadata_capture_cuda_graph( - self, - bs: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - encoder_lens: Optional[torch.Tensor] = None, - ): - # TODO: Support CUDA graph - raise ValueError( - "Torch native attention does not support CUDA graph for now. 
Please --disable-cuda-graph" - ) - - def init_forward_metadata_replay_cuda_graph( - self, - bs: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - seq_lens_sum: int, - encoder_lens: Optional[torch.Tensor] = None, - ): - # TODO: Support CUDA graph - raise ValueError( - "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph" - ) - - def get_cuda_graph_seq_len_fill_value(self): - # TODO: Support CUDA graph - raise ValueError( - "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph" - ) - def _run_sdpa_forward_extend( self, query: torch.Tensor, diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 1b7c4c46d26..fade8ed292d 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -1,16 +1,17 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch from sglang.srt.layers.attention import AttentionBackend -from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.layers.dp_attention import get_attention_tp_size +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInfo class TritonAttnBackend(AttentionBackend): @@ -28,17 +29,12 @@ def __init__(self, model_runner: ModelRunner): self.decode_attention_fwd = decode_attention_fwd self.extend_attention_fwd = extend_attention_fwd - if model_runner.server_args.enable_dp_attention: - self.num_head = model_runner.model_config.num_attention_heads - else: - self.num_head = ( - model_runner.model_config.num_attention_heads // model_runner.tp_size - ) + self.num_head = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) - if global_server_args_dict.get("triton_attention_reduce_in_fp32", False): - self.reduce_dtype = torch.float32 - else: - self.reduce_dtype = torch.float16 + self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits + self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] self.forward_metadata = None @@ -50,23 +46,23 @@ def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" if forward_batch.forward_mode.is_decode(): - start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32) - start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0) - - total_num_tokens = forward_batch.seq_lens_sum attn_logits = torch.empty( - (self.num_head, total_num_tokens), - dtype=self.reduce_dtype, + ( + forward_batch.batch_size, + self.num_head, + self.num_kv_splits, + self.v_head_dim + 1, + ), + dtype=torch.float32, device=self.device, ) - max_seq_len = torch.max(forward_batch.seq_lens).item() max_extend_len = None else: - start_loc = attn_logits = max_seq_len = None + attn_logits = None max_extend_len = torch.max(forward_batch.extend_seq_lens).item() - self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len + self.forward_metadata = attn_logits, max_extend_len def init_cuda_graph_state(self, max_bs: int): self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len 
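A note on the reshaped attn_logits buffer introduced in the hunk above: its shape (batch, num_head, num_kv_splits, v_head_dim + 1) holds, for every KV split, a normalized partial attention output in the first v_head_dim slots and that split's log-sum-exp in the final slot; a second reduction stage then merges the splits. The sketch below is illustrative only — the function and variable names are ours, not part of this patch — and shows, in plain PyTorch, why that merge reproduces an ordinary single-pass softmax.

    import torch

    def merge_kv_splits(partial_out: torch.Tensor, lse: torch.Tensor) -> torch.Tensor:
        # partial_out: [num_kv_splits, head_dim], each row is acc / e_sum for one split
        # lse:         [num_kv_splits], each entry is e_max + log(e_sum) for one split
        global_lse = torch.logsumexp(lse, dim=0)
        weights = torch.exp(lse - global_lse)      # contribution of each split
        return (weights[:, None] * partial_out).sum(dim=0)

    # Self-check against a single-pass softmax over the full key sequence.
    torch.manual_seed(0)
    q = torch.randn(16)                            # one head, head_dim = 16
    k = torch.randn(32, 16)                        # 32 cached keys
    v = torch.randn(32, 8)                         # v_head_dim = 8
    logits = k @ q
    ref = torch.softmax(logits, dim=0) @ v         # ordinary attention output

    outs, lses = [], []
    for chunk_logits, chunk_v in zip(logits.chunk(4), v.chunk(4)):
        m = chunk_logits.max()                                     # e_max for this split
        p = torch.exp(chunk_logits - m)
        outs.append((p[:, None] * chunk_v).sum(dim=0) / p.sum())   # acc / e_sum
        lses.append(m + torch.log(p.sum()))                        # e_max + log(e_sum)

    merged = merge_kv_splits(torch.stack(outs), torch.stack(lses))
    assert torch.allclose(merged, ref, atol=1e-5)

Because each split carries its own log-sum-exp, the merge is exact up to floating-point error, which is why the two-stage split-KV decode path can replace the previous single-stage kernels without changing results.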
@@ -75,26 +71,27 @@ def init_cuda_graph_state(self, max_bs: int): (max_bs,), dtype=torch.int32, device=self.device ) self.cuda_graph_attn_logits = torch.empty( - ( - self.num_head, - self.cuda_graph_max_total_num_tokens, - ), - dtype=self.reduce_dtype, + (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1), + dtype=torch.float32, device="cuda", ) def init_forward_metadata_capture_cuda_graph( self, bs: int, + num_tokens: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - encoder_lens=None, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], ): - # NOTE: encoder_lens expected to be zeros or None + assert encoder_lens is None, "Not supported" + assert forward_mode.is_decode(), "Not supported" + assert spec_info is None, "Not supported" + self.forward_metadata = ( - self.cuda_graph_start_loc, self.cuda_graph_attn_logits, - self.cuda_graph_max_seq_len, None, ) @@ -104,7 +101,9 @@ def init_forward_metadata_replay_cuda_graph( req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - encoder_lens=None, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], ): # NOTE: encoder_lens expected to be zeros or None self.cuda_graph_start_loc.zero_() @@ -115,9 +114,9 @@ def get_cuda_graph_seq_len_fill_value(self): def forward_extend( self, - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, @@ -133,7 +132,7 @@ def forward_extend( layer, forward_batch.out_cache_loc, k, v ) - start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata + _, max_extend_len = self.forward_metadata self.extend_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), k.contiguous(), @@ -154,9 +153,9 @@ def forward_extend( def forward_decode( self, - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, @@ -171,7 +170,7 @@ def forward_decode( else: o = torch.empty_like(q) - start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata + attn_logits, _ = self.forward_metadata if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( @@ -185,10 +184,9 @@ def forward_decode( o.view(-1, layer.tp_q_head_num, layer.v_head_dim), forward_batch.req_to_token_pool.req_to_token, forward_batch.req_pool_indices, - start_loc, forward_batch.seq_lens, attn_logits, - max_seq_len, + self.num_kv_splits, layer.scaling, layer.logit_cap, ) diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 56d38693f4f..2b4871af98c 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -17,8 +17,11 @@ """ # Adapted from -# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py -# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +# 
https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py + +import logging + import triton import triton.language as tl @@ -26,6 +29,13 @@ is_hip_ = is_hip() +logger = logging.getLogger(__name__) + +# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy. +logger.warning( + "The following error message 'operation scheduled before its operands' can be ignored." +) + @triton.jit def tanh(x): @@ -37,10 +47,10 @@ def tanh(x): def _fwd_kernel_stage1( Q, K_Buffer, + V_Buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, Att_Out, stride_req_to_tokens_b, @@ -48,152 +58,136 @@ def _fwd_kernel_stage1( stride_qh, stride_buf_kbs, stride_buf_kh, - att_stride_h, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, kv_group_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, BLOCK_N: tl.constexpr, - SPLIT_K: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, logit_cap: tl.constexpr, Lk: tl.constexpr, + Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) - split_k_id = tl.program_id(2) + split_kv_id = tl.program_id(2) - reduce_dtype = Att_Out.dtype.element_ty cur_kv_head = cur_head // kv_group_num offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d - q = tl.load(Q + off_q).to(reduce_dtype) - - kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K) - split_k_start = kv_len_per_split * split_k_id - split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len) - - for start_n in range(split_k_start, split_k_end, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, - mask=offs_n < split_k_end, - other=0, - ) - offs_buf_k = ( - k_loc[:, None] * stride_buf_kbs - + cur_kv_head * stride_buf_kh - + offs_d[None, :] - ) - k = tl.load( - K_Buffer + offs_buf_k, - mask=(offs_n[:, None] < split_k_end) & (offs_d[None, :] < Lk), - other=0.0, - ).to(reduce_dtype) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - - if logit_cap > 0: - att_value = logit_cap * tanh(att_value / logit_cap) + q = tl.load(Q + off_q, mask=mask_d, other=0.0) - off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) - tl.store(Att_Out + off_o, att_value, mask=offs_n < split_k_end) + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + offs_buf_k = ( + kv_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[None, :] + ) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), + other=0.0, + ) + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale 
-@triton.jit -def _fwd_kernel_stage2( - logits, - V_Buffer, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - stride_logic_h, - stride_buf_vbs, - stride_buf_vh, - stride_obs, - stride_oh, - stride_req_to_token_b, - kv_group_num: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - Lv: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) - cur_kv_head = cur_head // kv_group_num + qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) - offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :] - v_ptrs = V_Buffer + offs_buf_v + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max - e_max = float("-inf") - e_sum = 0.0 - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - v_index = tl.load( - Req_to_tokens - + cur_batch_req_idx * stride_req_to_token_b - + (start_n + offs_n), - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0, + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv ) - qk = tl.load( - logits - + cur_head * stride_logic_h - + (cur_batch_start_loc + start_n + offs_n), - mask=start_n + offs_n < cur_batch_seq_len, - other=float("-inf"), + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=(mask_dv), ) - n_e_max = tl.maximum(tl.max(qk, 0), e_max) - old_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max) - e_sum = e_sum * old_scale + tl.sum(p, 0) - v = tl.load( - v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv ) - acc = acc * old_scale + tl.sum(p[:, None] * v, 0) - e_max = n_e_max - acc = acc / e_sum - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=(offs_d < Lv)) + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + ) def _decode_att_m_fwd( q, k_buffer, + v_buffer, att_out, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ): - BLOCK = 32 - SPLIT_K = 8 + BLOCK = 64 + NUM_KV_SPLITS = num_kv_splits Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] batch, head_num = B_req_idx.shape[0], q.shape[1] - grid = (batch, head_num, SPLIT_K) + grid = (batch, head_num, NUM_KV_SPLITS) kv_group_num = q.shape[1] // k_buffer.shape[1] if kv_group_num == 1: @@ -202,14 +196,15 @@ def _decode_att_m_fwd( num_warps = 2 BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) _fwd_kernel_stage1[grid]( q, k_buffer, + v_buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, att_out, Req_to_tokens.stride(0), @@ -217,56 +212,20 @@ def _decode_att_m_fwd( q.stride(1), k_buffer.stride(0), 
k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), att_out.stride(0), + att_out.stride(1), + att_out.stride(2), kv_group_num=kv_group_num, BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, BLOCK_N=BLOCK, - SPLIT_K=SPLIT_K, + NUM_KV_SPLITS=NUM_KV_SPLITS, logit_cap=logit_cap, num_warps=num_warps, - num_stages=1, + num_stages=2, Lk=Lk, - ) - - -def _decode_softmax_reducev_fwd( - logits, - v_buffer, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, -): - BLOCK = 64 - batch, head = b_seq_len.shape[0], logits.shape[0] - grid = (batch, head, 1) - kv_group_num = logits.shape[0] // v_buffer.shape[1] - - num_warps = 1 - - Lv = v_buffer.shape[-1] - BLOCK_DMODEL = triton.next_power_of_2(Lv) - - _fwd_kernel_stage2[grid]( - logits, - v_buffer, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - logits.stride(0), - v_buffer.stride(0), - v_buffer.stride(1), - o.stride(0), - o.stride(1), - req_to_tokens.stride(0), - kv_group_num=kv_group_num, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=3, Lv=Lv, ) @@ -275,10 +234,10 @@ def _decode_softmax_reducev_fwd( def _fwd_grouped_kernel_stage1( Q, K_Buffer, + V_Buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, Att_Out, stride_req_to_tokens_b, @@ -286,23 +245,27 @@ def _fwd_grouped_kernel_stage1( stride_qh, stride_buf_kbs, stride_buf_kh, - att_stride_h, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, kv_group_num: tl.constexpr, q_head_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_H: tl.constexpr, - SPLIT_K: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, logit_cap: tl.constexpr, Lk: tl.constexpr, + Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head_id = tl.program_id(1) cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) - split_k_id = tl.program_id(2) - - reduce_dtype = Att_Out.dtype.element_ty + split_kv_id = tl.program_id(2) if BLOCK_H < kv_group_num: VALID_BLOCK_H: tl.constexpr = BLOCK_H @@ -313,171 +276,139 @@ def _fwd_grouped_kernel_stage1( mask_h = mask_h & (cur_head < q_head_num) offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] - q = tl.load( - Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0 - ).to(reduce_dtype) + q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk off_qpe = ( cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] ) - qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0).to(reduce_dtype) - - kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K) - split_k_start = kv_len_per_split * split_k_id - split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len) - - for start_n in range(split_k_start, split_k_end, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, - mask=offs_n < split_k_end, - other=0, + qpe = tl.load( + Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 ) - offs_buf_k = ( - k_loc[None, :] * stride_buf_kbs - 
+ cur_kv_head * stride_buf_kh - + offs_d[:, None] - ) - k = tl.load( - K_Buffer + offs_buf_k, - mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk), - other=0.0, - ).to(reduce_dtype) - qk = tl.dot(q, k) - if BLOCK_DPE > 0: - offs_buf_kpe = ( - k_loc[None, :] * stride_buf_kbs + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + offs_buf_k = ( + kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh - + offs_dpe[:, None] + + offs_d[:, None] ) - kpe = tl.load( - K_Buffer + offs_buf_kpe, - mask=offs_n[None, :] < split_k_end, + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), other=0.0, - ).to(reduce_dtype) - qk += tl.dot(qpe, kpe) - qk *= sm_scale - - if logit_cap > 0: - qk = logit_cap * tanh(qk / logit_cap) - - offs_o = cur_head[:, None] * att_stride_h + ( - cur_batch_in_all_start_index + offs_n[None, :] - ) - - tl.store( - Att_Out + offs_o, - qk, - mask=mask_h[:, None] & (offs_n[None, :] < split_k_end), - ) - - -@triton.jit -def _fwd_grouped_kernel_stage2( - logits, - V_Buffer, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - stride_logic_h, - stride_buf_vbs, - stride_buf_vh, - stride_obs, - stride_oh, - stride_req_to_token_b, - kv_group_num: tl.constexpr, - q_head_num: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_H: tl.constexpr, - Lv: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head_id = tl.program_id(1) - cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) - - if BLOCK_H < kv_group_num: - VALID_BLOCK_H: tl.constexpr = BLOCK_H - else: - VALID_BLOCK_H: tl.constexpr = kv_group_num - cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) - mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H - mask_h = mask_h & (cur_head < q_head_num) + ) + qk = tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where( + mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf") + ) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) - 
offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :] - v_ptrs = V_Buffer + offs_buf_v + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max - e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") - e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) - acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32) - - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - v_index = tl.load( - Req_to_tokens - + cur_batch_req_idx * stride_req_to_token_b - + (start_n + offs_n), - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0, + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv[None, :] ) - offs_qk = cur_head[:, None] * stride_logic_h + ( - cur_batch_start_loc + start_n + offs_n[None, :] + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), ) - qk = tl.load( - logits + offs_qk, - mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len), - other=float("-inf"), + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv ) - n_e_max = tl.maximum(tl.max(qk, 1), e_max) - old_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max[:, None]) - e_sum = e_sum * old_scale + tl.sum(p, 1) - v = tl.load( - v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, ) - p = p.to(v.dtype) - acc = acc * old_scale[:, None] + tl.dot(p, v) - e_max = n_e_max - - acc = acc / e_sum[:, None] - off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :] - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv)) def _decode_grouped_att_m_fwd( q, k_buffer, + v_buffer, att_out, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ): - BLOCK = 64 + BLOCK = 32 Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + # [TODO] work around shmem limit on MI3xx + if is_hip_ and Lk >= 576: + BLOCK = 16 if Lk == 576: BLOCK_DMODEL = 512 @@ -488,20 +419,19 @@ def _decode_grouped_att_m_fwd( else: BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) batch, head_num = B_req_idx.shape[0], q.shape[1] kv_group_num = q.shape[1] // k_buffer.shape[1] - BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) - SPLIT_K = 8 + BLOCK_H = 16 + NUM_KV_SPLITS = num_kv_splits grid = ( batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), - SPLIT_K, + NUM_KV_SPLITS, ) - num_warps = 4 - extra_kargs = {} if is_hip_: # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html @@ -511,10 +441,10 @@ def _decode_grouped_att_m_fwd( _fwd_grouped_kernel_stage1[grid]( q, k_buffer, + v_buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, att_out, Req_to_tokens.stride(0), @@ -522,41 +452,97 @@ def _decode_grouped_att_m_fwd( q.stride(1), k_buffer.stride(0), k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), att_out.stride(0), + att_out.stride(1), + att_out.stride(2), kv_group_num=kv_group_num, q_head_num=head_num, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, BLOCK_N=BLOCK, BLOCK_H=BLOCK_H, - SPLIT_K=SPLIT_K, + NUM_KV_SPLITS=NUM_KV_SPLITS, logit_cap=logit_cap, - num_warps=num_warps, - num_stages=1, + num_warps=4, + num_stages=2, Lk=Lk, + Lv=Lv, 
**extra_kargs, ) -def _decode_grouped_softmax_reducev_fwd( +@triton.jit +def _fwd_kernel_stage2( + Mid_O, + O, + B_Seqlen, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def _decode_softmax_reducev_fwd( logits, - v_buffer, + q, o, - req_to_tokens, - b_req_idx, - b_start_loc, + v_buffer, b_seq_len, + num_kv_splits, ): - BLOCK = 128 - batch, head_num = b_seq_len.shape[0], logits.shape[0] - kv_group_num = logits.shape[0] // v_buffer.shape[1] - BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) - grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1) - - num_warps = 8 - + batch, head_num = q.shape[0], q.shape[1] Lv = v_buffer.shape[-1] - BLOCK_DMODEL = triton.next_power_of_2(Lv) + BLOCK_DV = triton.next_power_of_2(Lv) + + NUM_KV_SPLITS = num_kv_splits extra_kargs = {} if is_hip_: @@ -564,28 +550,21 @@ def _decode_grouped_softmax_reducev_fwd( # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} - _fwd_grouped_kernel_stage2[grid]( + grid = (batch, head_num) + _fwd_kernel_stage2[grid]( logits, - v_buffer, o, - req_to_tokens, - b_req_idx, - b_start_loc, b_seq_len, logits.stride(0), - v_buffer.stride(0), - v_buffer.stride(1), + logits.stride(1), + logits.stride(2), o.stride(0), o.stride(1), - req_to_tokens.stride(0), - kv_group_num=kv_group_num, - q_head_num=head_num, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK, - BLOCK_H=BLOCK_H, + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=BLOCK_DV, Lv=Lv, - num_warps=num_warps, - num_stages=1, + num_warps=4, + num_stages=2, **extra_kargs, ) @@ -597,34 +576,25 @@ def decode_attention_fwd_normal( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap=0.0, ): _decode_att_m_fwd( q, k_buffer, + v_buffer, attn_logits, req_to_token, b_req_idx, - b_start_loc, b_seq_len, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) - _decode_softmax_reducev_fwd( - attn_logits, - v_buffer, - o, - req_to_token, - b_req_idx, - b_start_loc, - b_seq_len, - ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits) def decode_attention_fwd_grouped( @@ -634,34 +604,25 @@ def 
decode_attention_fwd_grouped( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap=0.0, ): _decode_grouped_att_m_fwd( q, k_buffer, + v_buffer, attn_logits, req_to_token, b_req_idx, - b_start_loc, b_seq_len, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) - _decode_grouped_softmax_reducev_fwd( - attn_logits, - v_buffer, - o, - req_to_token, - b_req_idx, - b_start_loc, - b_seq_len, - ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits) def decode_attention_fwd( @@ -671,13 +632,13 @@ def decode_attention_fwd( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap=0.0, ): + assert num_kv_splits == attn_logits.shape[2] kv_group_num = q.shape[1] // v_buffer.shape[1] if kv_group_num == 1: @@ -689,10 +650,9 @@ def decode_attention_fwd( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) @@ -705,10 +665,9 @@ def decode_attention_fwd( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index b7afd62e723..b2654f1f780 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -292,27 +292,33 @@ def extend_attention_fwd( BLOCK_DPE = 0 BLOCK_DV = triton.next_power_of_2(Lv) - if is_cuda_available and CUDA_CAPABILITY[0] >= 9: - if Lq <= 256: - BLOCK_M, BLOCK_N = (128, 64) - else: - BLOCK_M, BLOCK_N = (32, 64) - elif is_cuda_available and CUDA_CAPABILITY[0] >= 8: - if Lq <= 128: - BLOCK_M, BLOCK_N = (128, 128) - elif Lq <= 256: - BLOCK_M, BLOCK_N = (64, 64) - else: - BLOCK_M, BLOCK_N = (32, 64) + if is_hip_: + BLOCK_M, BLOCK_N = (64, 64) + num_warps = 4 + else: - BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) + if is_cuda_available and CUDA_CAPABILITY[0] >= 9: + if Lq <= 256: + BLOCK_M, BLOCK_N = (128, 64) + else: + BLOCK_M, BLOCK_N = (32, 64) + elif is_cuda_available and CUDA_CAPABILITY[0] >= 8: + if Lq <= 128: + BLOCK_M, BLOCK_N = (128, 128) + elif Lq <= 256: + BLOCK_M, BLOCK_N = (64, 64) + else: + BLOCK_M, BLOCK_N = (32, 64) + else: + BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) + + num_warps = 4 if Lk <= 64 else 8 sm_scale = sm_scale or 1.0 / (Lq**0.5) batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1] kv_group_num = q_extend.shape[1] // k_extend.shape[1] grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) - num_warps = 4 if Lk <= 64 else 8 num_stages = 1 extra_kargs = {} diff --git a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py index 9163eba68de..d022b972147 100644 --- a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py @@ -166,6 +166,12 @@ def _fwd_kernel( def context_attention_fwd( q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True ): + """ + q, k, v: [b * s, head, head_dim] + b_start_loc: [b] + b_seq_len: [b] + out: [b * s, head, head_dim] + """ if is_cuda_available and CUDA_CAPABILITY[0] > 8: BLOCK = 128 else: diff --git a/python/sglang/srt/layers/attention/vision.py 
b/python/sglang/srt/layers/attention/vision.py new file mode 100644 index 00000000000..03c4cfb46a8 --- /dev/null +++ b/python/sglang/srt/layers/attention/vision.py @@ -0,0 +1,407 @@ +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from sglang.srt.distributed import parallel_state +from sglang.srt.distributed import utils as dist_utils +from sglang.srt.layers.attention.triton_ops.prefill_attention import ( + context_attention_fwd, +) +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.quantization import QuantizationConfig + + +def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange( + torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) + + +def apply_rotary_emb_torch( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False +) -> torch.Tensor: + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) + sin = repeat( + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) + return torch.cat( + [ + x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], + ], + dim=-1, + ) + + +def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + output = apply_rotary_emb_torch(t_, cos, sin).type_as(t) + return output + + +class VisionAttention(nn.Module): + r""" + Multi-headed attention without any cache, mostly used for ViT. + + + Args: + use_qkv_parallel (bool, optional): If True, use QKV-parallel attention. + use_context_forward (bool, default to True): + if ``True``, a flash_attn style attention will be applied + Otherwise, a full-sequence attention will be applied. 
+ use_full_precision_softmax (bool, default to False): + if ``True``, the softmax will be performed in full-precision + Otherwise, it will be performed in half-precision + + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + use_qkv_parallel: bool, + quant_config: Optional[QuantizationConfig] = None, + dropout: float = 0.0, + use_context_forward: bool = True, + use_full_precision_softmax: bool = False, + flatten_batch: bool = False, + prefix: str = "", + ): + super().__init__() + self.use_context_forward = use_context_forward + world_size = parallel_state.get_tensor_model_parallel_world_size() + self.dropout = dropout + self.head_size = embed_dim // num_heads + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads + ) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, world_size + ) + + if self.use_context_forward: + self.qkv_backend = VisionTritonAttention() + else: + self.qkv_backend = VisionSdpaAttention( + head_size=self.head_size, + dropout=dropout, + flatten_batch=flatten_batch, + use_full_precision_softmax=use_full_precision_softmax, + ) + + self.use_qkv_parallel = use_qkv_parallel + if use_qkv_parallel: + self.qkv_proj = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.head_size, + total_num_heads=num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + else: + self.qkv_proj = ColumnParallelLinear( + input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.proj = RowParallelLinear( + input_size=embed_dim, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + rotary_pos_emb: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + x: [b, s, embed_dim] + cu_seqlens: [b] + Returns: + [s, b, num_heads * head] + """ + bsz, s, _ = x.shape + if self.use_qkv_parallel: + # [b, s, embed_dim] --> [b, s, embed_dim] + qkv, _ = self.qkv_proj(x) + q, k, v = qkv.chunk(3, dim=-1) + + # [b, s, embed_dim] --> [b * s, num_heads, head_size] + q, k, v = [ + x.reshape( + bsz * s, self.num_attention_heads_per_partition, -1 + ).contiguous() + for x in (q, k, v) + ] + else: + # [b, s, embed_dim] --> [s, b, embed_dim] + x = rearrange(x, "b s ... -> s b ...") + # [s, b, embed_dim] --> [s, b, head * 3 * head_size] + qkv, _ = self.qkv_proj(x) + # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size] + new_x_shape = qkv.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + qkv = qkv.view(*new_x_shape) + + # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size] + q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3) + + # [s, b, head, head_size] --> [b, s, head, head_size] + q, k, v = [ + rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v) + ] + + if rotary_pos_emb is not None: + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + + if self.use_qkv_parallel: + pass + else: + # [b, s, head, head_size] --> [b * s, head, head_size] + q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] + + output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask) + + if self.use_qkv_parallel: + # [b * s, h, head_size] --> [b, s, h * head_size] + output = rearrange(output, "(b s) ... h d -> b s ... 
(h d)", b=bsz) + + # [b, s, h * head_size] --> [b, s, h * head_size] + output, _ = self.proj(output) + else: + # [b * s, h, head_size] --> [s, b, h * head_size] + context_layer = rearrange( + output, "(b s) h d -> s b (h d)", b=bsz, s=s + ).contiguous() + + # [s, b, h * head_size] --> [s, b, h * head_size] + output, _ = self.proj(context_layer) + + # [s, b, h * head_size] --> [b, s, h * head_size] + output = output.view(bsz, s, -1) + + return output + + +class VisionSdpaAttention(nn.Module): + r""" + Scaled Dot Product Attention inner product + + """ + + # TODO: Should it be released after used? + _mask_cache = {} + + def __init__( + self, + head_size: int, + dropout: float = 0.0, + flatten_batch: bool = False, + use_full_precision_softmax: bool = False, + ): + super().__init__() + self.head_size = head_size + self.flatten_batch = flatten_batch + self.use_full_precision_softmax = use_full_precision_softmax + self.dropout = dropout + + def generate_patch_attention_mask( + self, + s: int, + bsz: int, + device, + cu_seqlens: Optional[torch.Tensor], + flatten_batch: bool = False, + dtype=torch.bfloat16, + ) -> torch.Tensor: + r""" + Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`. + + When `flatten_batch` is True: + - All sequences in the batch are flattened into a single dimension + - `s` represents the total number of tokens across all sequences in the batch + - Returns a unified mask of shape `(1, 1, s, s)` + + When `flatten_batch` is False: + - Each sequence has its own attention mask + - `s` represents the maximum sequence length in the batch + - Returns separate masks of shape `(b, 1, s, s)` + + Args: + flatten_batch: (bool): + If True, treats all sequences in the batch as a single flattened sequence + If False, generates separate masks for each sequence + + Returns: + Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`. 
+ """ + + cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist())) + + if cache_key in VisionSdpaAttention._mask_cache: + cached_mask = VisionSdpaAttention._mask_cache[cache_key] + # print(f"cache hit for key: {cache_key}") + return cached_mask.to(device=device, dtype=dtype) + + if cu_seqlens is None: + raise ValueError("Internal Error: cu_seqlens cannot be None") + + if flatten_batch: + mask = torch.zeros([1, s, s], device=device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + start = cu_seqlens[i - 1] + end = cu_seqlens[i] + mask[ + ..., + start:end, + start:end, + ] = True + else: + # [1, 1, 1, s] + row_indices = torch.arange(s, device=device).view(1, 1, 1, s) + # [1, 1, s, 1] + col_indices = torch.arange(s, device=device).view(1, 1, s, 1) + # [b, 1, 1, 1] + seq_lens = ( + (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1) + ) + + mask = (row_indices < seq_lens) & (col_indices < seq_lens) + + # Convert to attention mask format (False -> 0, True -> -inf) + mask = (~mask).to(dtype) * torch.finfo(dtype).min + + VisionSdpaAttention._mask_cache[cache_key] = mask + + return mask + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bsz: int, + cu_seqlens: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + cu_seqlens: [b] + Returns: + [b * s, h, head_size] + """ + + s = q.shape[0] // bsz + + # [b, 1, s, s] + if attention_mask is None: + attention_mask = self.generate_patch_attention_mask( + s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype + ) + q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]] + # [b, 1, s] + if self.use_full_precision_softmax: + scale = self.head_size**-0.5 + k_transposed = rearrange(k, "b h s d -> b h d s") + attn_weights = torch.matmul(q, k_transposed) * scale + del k, k_transposed + attn_weights = attn_weights + attention_mask + del attention_mask + # full-precision + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(q.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=False + ) + output = torch.matmul(attn_weights, v) + del attn_weights, v + else: + # SDPA + # [b, h, s, head_size] + output = F.scaled_dot_product_attention( + q, k, v, attention_mask, dropout_p=self.dropout + ) + + # [b, h, s, head_size] --> [b * s, h, head_size] + output = rearrange(output, "b h s d -> (b s) h d") + + return output + + +class VisionTritonAttention(nn.Module): + """ + Triton-implemented attention without a causal mask + """ + + def __init__( + self, + ): + super().__init__() + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + _bsz: int, + cu_seqlens: Optional[torch.Tensor], + **kwargs, + ) -> torch.Tensor: + r""" + Args: + cu_seqlens: [b] + Returns: + [b * s, h, head_size] + """ + + # [b * s, head, head_size] + output = torch.empty_like(q) + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = seq_lens.max().item() + context_attention_fwd( + q, + k, + v, + output, + cu_seqlens.cuda(), + seq_lens.cuda(), + max_seqlen, + is_causal=False, + ) + + return output diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py new file mode 100644 index 00000000000..36b87ca0ba0 --- /dev/null +++ b/python/sglang/srt/layers/dp_attention.py @@ -0,0 +1,71 @@ +import torch + +from sglang.srt.distributed import GroupCoordinator, get_tp_group + +_ATTN_TP_GROUP = None +_ATTN_TP_RANK = None 
+_ATTN_TP_SIZE = None +_DP_RANK = None +_DP_SIZE = None + + +def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size): + if not enable_dp_attention: + return tp_rank, tp_size, 0 + + attn_tp_size = tp_size // dp_size + dp_rank = tp_rank // attn_tp_size + attn_tp_rank = tp_rank % attn_tp_size + return attn_tp_rank, attn_tp_size, dp_rank + + +def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size): + global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE + + from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP + + _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info( + enable_dp_attention, tp_rank, tp_size, dp_size + ) + _DP_SIZE = dp_size + + tp_group = get_tp_group() + _ATTN_TP_GROUP = GroupCoordinator( + [ + list(range(head, head + _ATTN_TP_SIZE)) + for head in range(0, tp_size, _ATTN_TP_SIZE) + ], + tp_rank, + torch.distributed.get_backend(tp_group.device_group), + SYNC_TOKEN_IDS_ACROSS_TP, + False, + False, + False, + False, + group_name="attention_tp", + ) + + +def get_attention_tp_group(): + assert _ATTN_TP_GROUP is not None, "dp attention not initialized!" + return _ATTN_TP_GROUP + + +def get_attention_tp_rank(): + assert _ATTN_TP_RANK is not None, "dp attention not initialized!" + return _ATTN_TP_RANK + + +def get_attention_tp_size(): + assert _ATTN_TP_SIZE is not None, "dp attention not initialized!" + return _ATTN_TP_SIZE + + +def get_attention_dp_rank(): + assert _DP_RANK is not None, "dp attention not initialized!" + return _DP_RANK + + +def get_attention_dp_size(): + assert _DP_SIZE is not None, "dp attention not initialized!" + return _DP_SIZE diff --git a/python/sglang/srt/layers/ep_moe/__init__.py b/python/sglang/srt/layers/ep_moe/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/python/sglang/srt/layers/fused_moe_patch.py b/python/sglang/srt/layers/fused_moe_patch.py deleted file mode 100644 index baca2581150..00000000000 --- a/python/sglang/srt/layers/fused_moe_patch.py +++ /dev/null @@ -1,133 +0,0 @@ -""" -Torch-native implementation for FusedMoE. This is used for torch.compile. 
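The rank arithmetic in `compute_dp_attention_world_info` above is easiest to sanity-check with concrete numbers; the sketch below restates it standalone with made-up values (`tp_size=8`, `dp_size=2`, `tp_rank=5`):

```python
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
    # Same arithmetic as the helper above, restated for a standalone check.
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    attn_tp_size = tp_size // dp_size
    dp_rank = tp_rank // attn_tp_size
    attn_tp_rank = tp_rank % attn_tp_size
    return attn_tp_rank, attn_tp_size, dp_rank

# With tp_size=8 split into dp_size=2 data-parallel groups, global tp_rank=5
# becomes attention-TP rank 1 of size 4 inside DP group 1.
assert compute_dp_attention_world_info(True, 5, 8, 2) == (1, 4, 1)
```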
-It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204 -""" - -from typing import Callable, Optional - -import torch -from torch.nn import functional as F - - -def fused_topk_native( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, -): - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - M, _ = hidden_states.shape - topk_weights = torch.empty( - M, topk, dtype=torch.float32, device=hidden_states.device - ) - topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) - topk_weights = F.softmax(gating_output.float(), dim=-1) - topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1) - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids - - -# This is used by the Deepseek-V2 model -def grouped_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0, -): - - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - - scores = torch.softmax(gating_output, dim=-1) - num_token = scores.shape[0] - group_scores = ( - scores.view(num_token, num_expert_group, -1).max(dim=-1).values - ) # [n, n_group] - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ - 1 - ] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) - .reshape(num_token, -1) - ) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - return topk_weights, topk_ids - - -def select_experts_native( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, -): - # DeekSeekv2 uses grouped_top_k - if use_grouped_topk: - assert topk_group is not None - assert num_expert_group is not None - topk_weights, topk_ids = grouped_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - ) - else: - topk_weights, topk_ids = fused_topk_native( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - ) - return topk_weights, topk_ids - - -def fused_moe_forward_native( - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, -) -> torch.Tensor: - - if use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - topk_weights, topk_ids = grouped_topk( - x, - router_logits, - top_k, - renormalize, - num_expert_group, - topk_group, - ) - elif custom_routing_function is None: - topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize) - else: - topk_weights, topk_ids = custom_routing_function( - x, router_logits, top_k, 
renormalize - ) - - w13_weights = layer.w13_weight[topk_ids] - w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) - w2_weights = layer.w2_weight[topk_ids] - x1 = torch.einsum("ti,taoi -> tao", x, w1_weights) - x1 = F.silu(x1) - x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) - expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) - return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index bd95b9bccce..207ba8d1b7a 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -19,10 +19,10 @@ import torch import torch.nn as nn -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda_available -if is_flashinfer_available(): - from flashinfer.norm import ( +if is_cuda_available(): + from sgl_kernel import ( fused_add_rmsnorm, gemma_fused_add_rmsnorm, gemma_rmsnorm, @@ -121,8 +121,8 @@ def forward_cuda( return out -if not is_flashinfer_available(): +if not is_cuda_available(): logger.info( - "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries." + "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries." ) from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index f69058ff319..64daf79c50f 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -1,4 +1,4 @@ -# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/linear.py +"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py""" import logging from abc import abstractmethod @@ -7,7 +7,8 @@ import torch import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter -from vllm.distributed import ( + +from sglang.srt.distributed import ( divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -15,21 +16,18 @@ tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) - -# workaround -from vllm.model_executor.layers.linear import LinearBase -from vllm.model_executor.parameter import ( +from sglang.srt.layers.parameter import ( BasevLLMParameter, PackedColumnParameter, PackedvLLMParameter, PerTensorScaleParameter, RowvLLMParameter, ) - from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter from sglang.srt.utils import set_weight_attrs logger = logging.getLogger(__name__) @@ -41,8 +39,13 @@ "GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod", - "GPTQLinearMethod", "QQQLinearMethod", + "GPTQMarlin24LinearMethod", + "TPUInt8LinearMethod", + "GPTQLinearMethod", + "FBGEMMFp8LinearMethod", + "ModelOptFp8LinearMethod", + "IPEXAWQLinearMethod", ] @@ -168,6 +171,45 @@ def apply( return F.linear(x, layer.weight, bias) +class LinearBase(torch.nn.Module): + """Base linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. 
+ """ + + def __init__( + self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod() + else: + self.quant_method = quant_config.get_quant_method(self, prefix=prefix) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + class ReplicatedLinear(LinearBase): """Replicated linear layer. @@ -285,15 +327,23 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, output_sizes: Optional[List[int]] = None, prefix: str = "", + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, + use_presharded_weights: bool = False, ): super().__init__( input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix ) self.gather_output = gather_output + self.use_presharded_weights = use_presharded_weights # Divide the weight matrix along the last dimension. - tp_size = get_tensor_model_parallel_world_size() + if tp_rank is None: + tp_rank = get_tensor_model_parallel_rank() + if tp_size is None: + tp_size = get_tensor_model_parallel_world_size() + self.tp_rank, self.tp_size = tp_rank, tp_size assert self.quant_method is not None self.output_size_per_partition = divide(self.output_size, tp_size) self.output_partition_sizes = [self.output_size_per_partition] @@ -334,7 +384,6 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) # Special case for GGUF @@ -354,8 +403,9 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # no need to narrow here if output_dim is not None and not use_bitsandbytes_4bit: shard_size = param_data.shape[output_dim] - start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + start_idx = self.tp_rank * shard_size + if not self.use_presharded_weights: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for loading scales off disk, which often do not # have a shape (such as in the case of AutoFP8). 
@@ -371,7 +421,11 @@ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): if len(loaded_weight.shape) == 0: assert loaded_weight.numel() == 1 loaded_weight = loaded_weight.reshape(1) - param.load_column_parallel_weight(loaded_weight=loaded_weight) + param.load_column_parallel_weight( + loaded_weight, + tp_rank=self.tp_rank, + use_presharded_weights=self.use_presharded_weights, + ) def forward(self, input_): bias = self.bias if not self.skip_bias_add else None @@ -391,7 +445,7 @@ def extra_repr(self) -> str: s = f"in_features={self.input_size}" s += f", output_features={self.output_size_per_partition}" s += f", bias={self.bias is not None}" - s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", tp_size={self.tp_size}" s += f", gather_output={self.gather_output}" return s @@ -429,10 +483,18 @@ def __init__( params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, + use_presharded_weights: bool = False, ): self.output_sizes = output_sizes - tp_size = get_tensor_model_parallel_world_size() + if tp_rank is None: + tp_rank = get_tensor_model_parallel_rank() + if tp_size is None: + tp_size = get_tensor_model_parallel_world_size() + self.tp_rank, self.tp_size = tp_rank, tp_size assert all(output_size % tp_size == 0 for output_size in output_sizes) + self.use_presharded_weights = use_presharded_weights super().__init__( input_size=input_size, output_size=sum(output_sizes), @@ -442,7 +504,11 @@ def __init__( params_dtype=params_dtype, quant_config=quant_config, prefix=prefix, + tp_rank=tp_rank, + tp_size=tp_size, + use_presharded_weights=use_presharded_weights, ) + self.prefix = prefix def weight_loader( self, @@ -461,12 +527,9 @@ def weight_loader( return if is_gguf_weight: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - output_dim = getattr(param, "output_dim", None) - shard_size = loaded_weight.size(output_dim) // tp_size - start_idx = tp_rank * shard_size + shard_size = loaded_weight.size(output_dim) // self.tp_size + start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) @@ -520,11 +583,9 @@ def weight_loader( return assert loaded_shard_id < len(self.output_sizes) - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() if output_dim is not None: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size + shard_size = self.output_sizes[loaded_shard_id] // self.tp_size # Special case for quantization. # If quantized, we need to adjust the offset and size to account # for the packing. @@ -543,10 +604,10 @@ def weight_loader( shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id param_data = param_data.narrow(output_dim, shard_offset, shard_size) - start_idx = tp_rank * shard_size + start_idx = self.tp_rank * shard_size # bitsandbytes loads the weights of the specific portion # no need to narrow here - if not use_bitsandbytes_4bit: + if not use_bitsandbytes_4bit and not self.use_presharded_weights: loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for AQLM codebooks. 
elif is_metadata: @@ -622,20 +683,33 @@ def weight_loader_v2( elif type(param) in (RowvLLMParameter, BasevLLMParameter): param.load_merged_column_weight(loaded_weight=loaded_weight) return + # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) return assert loaded_shard_id < len(self.output_sizes) - tp_size = get_tensor_model_parallel_world_size() - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size + if isinstance(param, BlockQuantScaleParameter): + weight_block_size = self.quant_method.quant_config.weight_block_size + block_n, _ = weight_block_size[0], weight_block_size[1] + shard_offset = ( + (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n + ) // self.tp_size + shard_size = ( + (self.output_sizes[loaded_shard_id] + block_n - 1) + // block_n + // self.tp_size + ) + else: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size + shard_size = self.output_sizes[loaded_shard_id] // self.tp_size param.load_merged_column_weight( loaded_weight=loaded_weight, shard_id=loaded_shard_id, shard_offset=shard_offset, shard_size=shard_size, + use_presharded_weights=self.use_presharded_weights, ) @@ -676,6 +750,9 @@ def __init__( params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, + load_presharded_attn: bool = False, ): self.hidden_size = hidden_size self.head_size = head_size @@ -684,7 +761,11 @@ def __init__( total_num_kv_heads = total_num_heads self.total_num_kv_heads = total_num_kv_heads # Divide the weight matrix along the last dimension. - tp_size = get_tensor_model_parallel_world_size() + if tp_rank is None: + tp_rank = get_tensor_model_parallel_rank() + if tp_size is None: + tp_size = get_tensor_model_parallel_world_size() + self.tp_rank, self.tp_size = tp_rank, tp_size self.num_heads = divide(self.total_num_heads, tp_size) if tp_size >= self.total_num_kv_heads: self.num_kv_heads = 1 @@ -701,6 +782,7 @@ def __init__( self.num_kv_heads * self.head_size * tp_size, # k_proj self.num_kv_heads * self.head_size * tp_size, # v_proj ] + self.use_presharded_weights = load_presharded_attn super().__init__( input_size=input_size, @@ -711,6 +793,9 @@ def __init__( params_dtype=params_dtype, quant_config=quant_config, prefix=prefix, + tp_rank=tp_rank, + tp_size=tp_size, + use_presharded_weights=self.use_presharded_weights, ) def _get_shard_offset_mapping(self, loaded_shard_id: str): @@ -769,9 +854,10 @@ def _load_fused_module_from_checkpoint( shard_size=shard_size, shard_offset=shard_offset ) - loaded_weight_shard = loaded_weight.narrow( - param.output_dim, shard_offset, shard_size - ) + if not self.use_presharded_weights: + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) self.weight_loader_v2(param, loaded_weight_shard, shard_id) def weight_loader_v2( @@ -787,6 +873,7 @@ def weight_loader_v2( elif type(param) in (RowvLLMParameter, BasevLLMParameter): param.load_qkv_weight(loaded_weight=loaded_weight) return + # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) return @@ -795,12 +882,20 @@ def weight_loader_v2( shard_offset = self._get_shard_offset_mapping(loaded_shard_id) shard_size = self._get_shard_size_mapping(loaded_shard_id) + if isinstance(param, BlockQuantScaleParameter): + weight_block_size = 
self.quant_method.quant_config.weight_block_size + block_n, _ = weight_block_size[0], weight_block_size[1] + shard_offset = (shard_offset + block_n - 1) // block_n + shard_size = (shard_size + block_n - 1) // block_n + param.load_qkv_weight( loaded_weight=loaded_weight, num_heads=self.num_kv_head_replicas, shard_id=loaded_shard_id, shard_offset=shard_offset, shard_size=shard_size, + tp_rank=self.tp_rank, + use_presharded_weights=self.use_presharded_weights, ) def weight_loader( @@ -821,12 +916,9 @@ def weight_loader( return if is_gguf_weight: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - output_dim = getattr(param, "output_dim", None) - shard_size = loaded_weight.size(output_dim) // tp_size - start_idx = tp_rank * shard_size + shard_size = loaded_weight.size(output_dim) // self.tp_size + start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) @@ -909,13 +1001,13 @@ def weight_loader( param, orig_qkv_offsets, shard_id ) - loaded_weight_shard = loaded_weight.narrow( - output_dim, shard_offset, shard_size - ) + if not self.use_presharded_weights: + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size + ) self.weight_loader(param, loaded_weight_shard, shard_id) return - tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] # If output dim is defined, use the default loading process. @@ -965,14 +1057,14 @@ def weight_loader( param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": - shard_id = tp_rank + shard_id = self.tp_rank else: - shard_id = tp_rank // self.num_kv_head_replicas + shard_id = self.tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size # bitsandbytes loads the weights of the specific portion # no need to narrow here - if not use_bitsandbytes_4bit: + if not use_bitsandbytes_4bit and not self.use_presharded_weights: loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for for AQLM codebooks. @@ -1036,6 +1128,9 @@ def __init__( reduce_results: bool = True, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, + use_presharded_weights: bool = False, ): super().__init__( input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix @@ -1045,10 +1140,14 @@ def __init__( self.reduce_results = reduce_results # Divide the weight matrix along the last dimension. 
- self.tp_rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() + if tp_rank is None: + tp_rank = get_tensor_model_parallel_rank() + if tp_size is None: + tp_size = get_tensor_model_parallel_world_size() + self.tp_rank, self.tp_size = tp_rank, tp_size self.input_size_per_partition = divide(input_size, self.tp_size) assert self.quant_method is not None + self.use_presharded_weights = use_presharded_weights self.quant_method.create_weights( layer=self, @@ -1082,8 +1181,6 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() input_dim = getattr(param, "input_dim", None) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) @@ -1097,15 +1194,19 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): if is_gguf_weight and isinstance(param, UninitializedParameter): weight_shape = list(loaded_weight.shape) if input_dim: - weight_shape[input_dim] = weight_shape[input_dim] // tp_size + weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data # bitsandbytes loads the weights of the specific portion # no need to narrow here - if input_dim is not None and not use_bitsandbytes_4bit: + if ( + input_dim is not None + and not use_bitsandbytes_4bit + and not self.use_presharded_weights + ): shard_size = param_data.shape[input_dim] - start_idx = tp_rank * shard_size + start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) # Special case for loading scales off disk, which often do not @@ -1124,17 +1225,27 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor assert loaded_weight.numel() == 1 loaded_weight = loaded_weight.reshape(1) - param.load_row_parallel_weight(loaded_weight=loaded_weight) + if isinstance(param, BasevLLMParameter): + # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py, + # It supports additional parameters like tp_rank and use_presharded_weights. + param.load_row_parallel_weight( + loaded_weight, + tp_rank=self.tp_rank, + use_presharded_weights=self.use_presharded_weights, + ) + else: + # `params` is defined in `vllm/model_executor/parameter.py`, + # It does not support additional parameters. + param.load_row_parallel_weight(loaded_weight) def forward(self, input_): if self.input_is_parallel: input_parallel = input_ else: - tp_rank = get_tensor_model_parallel_rank() splitted_input = split_tensor_along_last_dim( input_, num_partitions=self.tp_size ) - input_parallel = splitted_input[tp_rank].contiguous() + input_parallel = splitted_input[self.tp_rank].contiguous() # Matrix multiply. 
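The row-parallel change above caches `tp_rank`/`tp_size` on the layer instead of re-querying the distributed group inside `forward`. A minimal sketch of what the non-`input_is_parallel` path does with the cached rank (toy tensor, `torch.chunk` standing in for `split_tensor_along_last_dim`, no distributed group needed):

```python
import torch

def take_row_parallel_slice(x: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    """Split the hidden dimension into tp_size chunks and keep this rank's chunk."""
    return torch.chunk(x, tp_size, dim=-1)[tp_rank].contiguous()

x = torch.randn(2, 16)  # [tokens, hidden]
print(take_row_parallel_slice(x, tp_rank=3, tp_size=4).shape)  # torch.Size([2, 4])
```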
assert self.quant_method is not None diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 274c4c311ec..08ee5a3509b 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -14,151 +14,115 @@ """Logits processing.""" import dataclasses +import logging from typing import List, Optional, Union import torch +import triton +import triton.language as tl from torch import nn -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) - from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) + +logger = logging.getLogger(__name__) @dataclasses.dataclass class LogitsProcessorOutput: + ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor # The logits of the next tokens. shape: [#seq, vocab_size] next_token_logits: torch.Tensor - # The logprobs of the next tokens. shape: [#seq, vocab_size] - next_token_logprobs: torch.Tensor = None - - # The normlaized logprobs of prompts. shape: [#seq] - normalized_prompt_logprobs: torch.Tensor = None - # The logprobs of input tokens. shape: [#token, vocab_size] + # Used by speculative decoding (EAGLE) + # The last hidden layers + hidden_states: Optional[torch.Tensor] = None + + ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler + # The logprobs of the next tokens. shape: [#seq] + next_token_logprobs: Optional[torch.Tensor] = None + # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k] + next_token_top_logprobs_val: Optional[List] = None + next_token_top_logprobs_idx: Optional[List] = None + + ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor + # The logprobs of input tokens. shape: [#token] input_token_logprobs: torch.Tensor = None - - # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id) - input_top_logprobs: List = None - # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id) - output_top_logprobs: List = None + # The logprobs and ids of the top-k tokens in input positions. 
shape: [#seq, #token, k] + input_top_logprobs_val: List = None + input_top_logprobs_idx: List = None @dataclasses.dataclass class LogitsMetadata: forward_mode: ForwardMode - top_logprobs_nums: Optional[List[int]] - - return_logprob: bool = False - return_top_logprob: bool = False + capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL + extend_return_logprob: bool = False + extend_return_top_logprob: bool = False extend_seq_lens: Optional[torch.Tensor] = None extend_seq_lens_cpu: Optional[List[int]] = None - extend_logprob_start_lens_cpu: Optional[List[int]] = None extend_logprob_pruned_lens_cpu: Optional[List[int]] = None + top_logprobs_nums: Optional[List[int]] = None @classmethod def from_forward_batch(cls, forward_batch: ForwardBatch): - extend_logprob_pruned_lens_cpu = None - - if forward_batch.return_logprob: - return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums) - if forward_batch.forward_mode.is_extend(): - extend_logprob_pruned_lens_cpu = [ - extend_len - start_len - for extend_len, start_len in zip( - forward_batch.extend_seq_lens_cpu, - forward_batch.extend_logprob_start_lens_cpu, - ) - ] + if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob: + extend_return_logprob = True + extend_return_top_logprob = any( + x > 0 for x in forward_batch.top_logprobs_nums + ) + extend_logprob_pruned_lens_cpu = [ + extend_len - start_len + for extend_len, start_len in zip( + forward_batch.extend_seq_lens_cpu, + forward_batch.extend_logprob_start_lens_cpu, + ) + ] else: - return_top_logprob = False + extend_return_logprob = extend_return_top_logprob = ( + extend_logprob_pruned_lens_cpu + ) = False return cls( forward_mode=forward_batch.forward_mode, - top_logprobs_nums=forward_batch.top_logprobs_nums, - return_logprob=forward_batch.return_logprob, - return_top_logprob=return_top_logprob, + capture_hidden_mode=forward_batch.capture_hidden_mode, + extend_return_logprob=extend_return_logprob, + extend_return_top_logprob=extend_return_top_logprob, extend_seq_lens=forward_batch.extend_seq_lens, extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu, extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu, extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu, + top_logprobs_nums=forward_batch.top_logprobs_nums, ) class LogitsProcessor(nn.Module): - def __init__(self, config, skip_all_gather: bool = False): + def __init__( + self, config, skip_all_gather: bool = False, logit_scale: Optional[float] = None + ): super().__init__() self.config = config + self.logit_scale = logit_scale self.do_tensor_parallel_all_gather = ( not skip_all_gather and get_tensor_model_parallel_world_size() > 1 ) - - def _get_normalized_prompt_logprobs( - self, - input_token_logprobs: torch.Tensor, - logits_metadata: LogitsMetadata, - ): - logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32) - pruned_lens = torch.tensor( - logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda" + self.final_logit_softcapping = getattr( + self.config, "final_logit_softcapping", None ) - - start = torch.zeros_like(pruned_lens) - start[1:] = torch.cumsum(pruned_lens[:-1], dim=0) - end = torch.clamp( - start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1 - ) - sum_logp = ( - logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start] - ) - normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1) - return normalized_prompt_logprobs - - @staticmethod - def get_top_logprobs(all_logprobs: torch.Tensor, 
logits_metadata: LogitsMetadata): - max_k = max(logits_metadata.top_logprobs_nums) - ret = all_logprobs.topk(max_k, dim=1) - values = ret.values.tolist() - indices = ret.indices.tolist() - - if logits_metadata.forward_mode.is_decode(): - output_top_logprobs = [] - for i, k in enumerate(logits_metadata.top_logprobs_nums): - output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k]))) - return None, output_top_logprobs - else: - input_top_logprobs, output_top_logprobs = [], [] - - pt = 0 - for k, pruned_len in zip( - logits_metadata.top_logprobs_nums, - logits_metadata.extend_logprob_pruned_lens_cpu, - ): - if pruned_len <= 0: - input_top_logprobs.append([]) - output_top_logprobs.append([]) - continue - - input_top_logprobs.append( - [ - list(zip(values[pt + j][:k], indices[pt + j][:k])) - for j in range(pruned_len - 1) - ] - ) - output_top_logprobs.append( - list( - zip( - values[pt + pruned_len - 1][:k], - indices[pt + pruned_len - 1][:k], - ) - ) - ) - pt += pruned_len - - return input_top_logprobs, output_top_logprobs + if ( + self.final_logit_softcapping is not None + and self.final_logit_softcapping < 0 + ): + self.final_logit_softcapping = None def forward( self, @@ -166,160 +130,202 @@ def forward( hidden_states, lm_head: VocabParallelEmbedding, logits_metadata: Union[LogitsMetadata, ForwardBatch], - ): + ) -> LogitsProcessorOutput: if isinstance(logits_metadata, ForwardBatch): logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) - assert isinstance(logits_metadata, LogitsMetadata) # Get the last hidden states and last logits for the next token prediction - if logits_metadata.forward_mode.is_decode(): - last_index = None - last_hidden = hidden_states - else: + if ( + logits_metadata.forward_mode.is_decode_or_idle() + or logits_metadata.forward_mode.is_target_verify() + ): + pruned_states = hidden_states + sample_indices = None + elif ( + logits_metadata.forward_mode.is_extend() + and not logits_metadata.extend_return_logprob + ): + # Prefill without input logprobs. last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 - last_hidden = hidden_states[last_index] - - last_logits = self._get_logits(last_hidden, lm_head) - if self.do_tensor_parallel_all_gather: - last_logits = tensor_model_parallel_all_gather(last_logits) - last_logits = last_logits[:, : self.config.vocab_size].float() - - if hasattr(self.config, "final_logit_softcapping"): - last_logits.div_(self.config.final_logit_softcapping) - torch.tanh(last_logits, out=last_logits) - last_logits.mul_(self.config.final_logit_softcapping) + pruned_states = hidden_states[last_index] + sample_indices = None + else: + # Slice the requested tokens to compute logprob + sample_index_pt = -1 + sample_indices = [] + pt, pruned_states, pruned_input_ids = 0, [], [] + for start_len, extend_len in zip( + logits_metadata.extend_logprob_start_lens_cpu, + logits_metadata.extend_seq_lens_cpu, + ): + pruned_states.append(hidden_states[pt + start_len : pt + extend_len]) + sample_index_pt += extend_len - start_len + sample_indices.append(sample_index_pt) + pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) + pt += extend_len + + pruned_states = torch.cat(pruned_states) + + # Compute logits for both input and sampled tokens. 
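The slicing loop above keeps, for each request, only the prompt suffix whose logprobs were requested and records where each request's last token lands in the concatenated buffer. A small standalone check of that bookkeeping with toy lengths (all values below are made up):

```python
import torch

hidden = torch.arange(10).unsqueeze(-1)  # 10 tokens from two requests, 1-dim "hidden states"
extend_seq_lens = [4, 6]                 # tokens per request
logprob_start_lens = [1, 3]              # skip this many prompt tokens per request

pt, pruned, sample_indices, sample_index_pt = 0, [], [], -1
for start_len, extend_len in zip(logprob_start_lens, extend_seq_lens):
    pruned.append(hidden[pt + start_len : pt + extend_len])
    sample_index_pt += extend_len - start_len
    sample_indices.append(sample_index_pt)
    pt += extend_len

pruned_states = torch.cat(pruned)
print(pruned_states.squeeze(-1).tolist())  # [1, 2, 3, 7, 8, 9]
print(sample_indices)                      # [2, 5] -> last kept token of each request
```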
+ logits = self._get_logits(pruned_states, lm_head, logits_metadata) + sampled_logits = ( + logits[sample_indices] if sample_indices is not None else logits + ) - # Return only last_logits if logprob is not requested - if not logits_metadata.return_logprob: + if ( + not logits_metadata.extend_return_logprob + or logits_metadata.capture_hidden_mode.need_capture() + ): + # Decode mode or extend mode without return_logprob. return LogitsProcessorOutput( - next_token_logits=last_logits, - next_token_logprobs=None, - normalized_prompt_logprobs=None, - input_token_logprobs=None, - input_top_logprobs=None, - output_top_logprobs=None, + next_token_logits=sampled_logits, + hidden_states=( + hidden_states + if logits_metadata.capture_hidden_mode.is_full() + else ( + pruned_states + if logits_metadata.capture_hidden_mode.is_last() + else None + ) + ), ) else: - last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1) - - if logits_metadata.forward_mode.is_decode(): - if logits_metadata.return_top_logprob: - output_top_logprobs = self.get_top_logprobs( - last_logprobs, logits_metadata - )[1] - else: - output_top_logprobs = None - return LogitsProcessorOutput( - next_token_logits=last_logits, - next_token_logprobs=last_logprobs, - normalized_prompt_logprobs=None, - input_token_logprobs=None, - input_top_logprobs=None, - output_top_logprobs=output_top_logprobs, - ) + input_logprobs = logits + del hidden_states, logits + + # Normalize the logprob w/o temperature, top-p + input_logprobs = self.compute_temp_top_p_normalized_logprobs( + input_logprobs, logits_metadata + ) + + # Get the logprob of top-k tokens + if logits_metadata.extend_return_top_logprob: + ( + input_top_logprobs_val, + input_top_logprobs_idx, + ) = self.get_top_logprobs(input_logprobs, logits_metadata) else: - # Slice the requested tokens to compute logprob - pt, states, pruned_input_ids = 0, [], [] - for start_len, extend_len in zip( - logits_metadata.extend_logprob_start_lens_cpu, - logits_metadata.extend_seq_lens_cpu, - ): - states.append(hidden_states[pt + start_len : pt + extend_len]) - pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) - pt += extend_len - - # Compute the logits and logprobs for all required tokens - states = torch.cat(states, dim=0) - all_logits = self._get_logits(states, lm_head) - if self.do_tensor_parallel_all_gather: - all_logits = tensor_model_parallel_all_gather(all_logits) - all_logits = all_logits[:, : self.config.vocab_size].float() - - if hasattr(self.config, "final_logit_softcapping"): - all_logits.div_(self.config.final_logit_softcapping) - torch.tanh(all_logits, out=all_logits) - all_logits.mul_(self.config.final_logit_softcapping) - - all_logprobs = all_logits - del all_logits, hidden_states - all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1) - - # Get the logprob of top-k tokens - if logits_metadata.return_top_logprob: - input_top_logprobs, output_top_logprobs = self.get_top_logprobs( - all_logprobs, logits_metadata - ) - else: - input_top_logprobs = output_top_logprobs = None - - # Compute the normalized logprobs for the requested tokens. - # Note that we pad a zero at the end for easy batching. 
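The replacement gather that follows picks, at every kept position, the logprob of the next input token, padding a dummy id 0 at the end so the gather stays rectangular. A toy check of that shift-and-gather, with illustrative shapes only:

```python
import torch

logprobs = torch.log_softmax(torch.randn(4, 6), dim=-1)  # 4 positions, vocab of 6
input_ids = torch.tensor([2, 5, 1, 3])

# Shift ids left by one and pad a 0 so position i picks up the logprob of token i+1.
targets = torch.cat([input_ids[1:], torch.tensor([0])])
input_token_logprobs = logprobs[torch.arange(logprobs.shape[0]), targets]
print(input_token_logprobs.shape)  # torch.Size([4]); the last entry is a padding value
```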
- input_token_logprobs = all_logprobs[ - torch.arange(all_logprobs.shape[0], device="cuda"), - torch.cat( - [ - torch.cat(pruned_input_ids)[1:], - torch.tensor([0], device="cuda"), - ] - ), - ] - normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( - input_token_logprobs, - logits_metadata, - ) + input_top_logprobs_val = input_top_logprobs_idx = None - return LogitsProcessorOutput( - next_token_logits=last_logits, - next_token_logprobs=last_logprobs, - normalized_prompt_logprobs=normalized_prompt_logprobs, - input_token_logprobs=input_token_logprobs, - input_top_logprobs=input_top_logprobs, - output_top_logprobs=output_top_logprobs, - ) + input_token_logprobs = input_logprobs[ + torch.arange(input_logprobs.shape[0], device=input_logprobs.device), + torch.cat( + [ + torch.cat(pruned_input_ids)[1:], + torch.tensor([0], device=input_logprobs.device), + ] + ), + ] + + return LogitsProcessorOutput( + next_token_logits=sampled_logits, + input_token_logprobs=input_token_logprobs, + input_top_logprobs_val=input_top_logprobs_val, + input_top_logprobs_idx=input_top_logprobs_idx, + ) def _get_logits( self, hidden_states: torch.Tensor, lm_head: VocabParallelEmbedding, + logits_metadata: LogitsMetadata, embedding_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + """Get logits from hidden_states.""" + if hasattr(lm_head, "weight"): logits = torch.matmul(hidden_states, lm_head.weight.T) else: # GGUF models logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias) + + if self.logit_scale is not None: + logits.mul_(self.logit_scale) + + if self.do_tensor_parallel_all_gather: + logits = tensor_model_parallel_all_gather(logits) + + logits = logits[:, : self.config.vocab_size].float() + + if self.final_logit_softcapping: + fused_softcap(logits, self.final_logit_softcapping) + return logits + @staticmethod + def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata): + max_k = max(logits_metadata.top_logprobs_nums) + ret = all_logprobs.topk(max_k, dim=1) + values = ret.values.tolist() + indices = ret.indices.tolist() + + input_top_logprobs_val, input_top_logprobs_idx = [], [] + + pt = 0 + for k, pruned_len in zip( + logits_metadata.top_logprobs_nums, + logits_metadata.extend_logprob_pruned_lens_cpu, + ): + if pruned_len <= 0: + input_top_logprobs_val.append([]) + input_top_logprobs_idx.append([]) + continue + + input_top_logprobs_val.append( + [values[pt + j][:k] for j in range(pruned_len - 1)] + ) + input_top_logprobs_idx.append( + [indices[pt + j][:k] for j in range(pruned_len - 1)] + ) + pt += pruned_len + + return input_top_logprobs_val, input_top_logprobs_idx + + @staticmethod + def compute_temp_top_p_normalized_logprobs( + last_logits: torch.Tensor, logits_metadata: LogitsMetadata + ) -> torch.Tensor: + # TODO: Implement the temp and top-p normalization + return torch.nn.functional.log_softmax(last_logits, dim=-1) + + +@triton.jit +def fused_softcap_kernel( + full_logits_ptr, + softcapping_value, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0).to(tl.int64) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + + # Load values + x = tl.load(full_logits_ptr + offsets, mask=mask) + + # Perform operations in-place + x = x / softcapping_value + + # Manual tanh implementation using exp + exp2x = tl.exp(2 * x) + x = (exp2x - 1) / (exp2x + 1) + + x = x * softcapping_value + + # Store result + tl.store(full_logits_ptr + offsets, x, mask=mask) + + +def 
fused_softcap(full_logits, final_logit_softcapping): + n_elements = full_logits.numel() + BLOCK_SIZE = 1024 + grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE, 1, 1) -def test(): - all_logprobs = torch.tensor( - # s s s - [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]], - dtype=torch.float32, - device="cuda", + fused_softcap_kernel[grid]( + full_logits_ptr=full_logits, + softcapping_value=final_logit_softcapping, + n_elements=n_elements, + BLOCK_SIZE=BLOCK_SIZE, ) - seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda") - input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda") - - token_logprobs = all_logprobs[ - torch.arange(all_logprobs.shape[0], device="cuda"), - torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]), - ] - logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32) - - len_cumsum = torch.cumsum(seq_lens, dim=0) - start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0) - end = start + seq_lens - 2 - start.clamp_(min=0, max=token_logprobs.shape[0] - 1) - end.clamp_(min=0, max=token_logprobs.shape[0] - 1) - sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start] - - # assert logprobs == [2, _, 2, 4, _] - print("token logprobs", token_logprobs) - print("start", start) - print("end", end) - print("sum_logp", sum_logp) - - -if __name__ == "__main__": - test() + return full_logits diff --git a/python/sglang/srt/distributed/device_communicators/__init__.py b/python/sglang/srt/layers/moe/ep_moe/__init__.py similarity index 100% rename from python/sglang/srt/distributed/device_communicators/__init__.py rename to python/sglang/srt/layers/moe/ep_moe/__init__.py diff --git a/python/sglang/srt/layers/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py similarity index 100% rename from python/sglang/srt/layers/ep_moe/kernels.py rename to python/sglang/srt/layers/moe/ep_moe/kernels.py diff --git a/python/sglang/srt/layers/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py similarity index 91% rename from python/sglang/srt/layers/ep_moe/layer.py rename to python/sglang/srt/layers/moe/ep_moe/layer.py index eca119845a7..bc927621a84 100644 --- a/python/sglang/srt/layers/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -4,27 +4,27 @@ import torch from torch.nn import Module from vllm import _custom_ops as ops -from vllm.distributed import ( +from vllm.model_executor.custom_op import CustomOp + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod - from sglang.srt.layers.custom_op_util import register_custom_op -from sglang.srt.layers.ep_moe.kernels import ( +from sglang.srt.layers.moe.ep_moe.kernels import ( grouped_gemm_triton, post_reorder_triton_kernel, pre_reorder_triton_kernel, run_moe_ep_preproess, silu_and_mul_triton_kernel, ) -from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk -from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from sglang.srt.utils import is_hip, 
set_weight_attrs logger = logging.getLogger(__name__) @@ -113,6 +113,9 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, prefix: str = "", + correction_bias: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + activation: str = "silu", ): super().__init__() @@ -138,6 +141,9 @@ def __init__( assert num_expert_group is not None and topk_group is not None self.num_expert_group = num_expert_group self.topk_group = topk_group + self.correction_bias = correction_bias + self.custom_routing_function = custom_routing_function + self.activation = activation if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod() @@ -164,19 +170,23 @@ def __init__( def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None + assert self.activation == "silu" if self.grouped_gemm_runner is None: self.grouped_gemm_runner = GroupedGemmRunner( hidden_states.device, use_flashinfer=False # TODO: use flashinfer ) - topk_weights, topk_ids = self.select_experts( - hidden_states, - router_logits, - self.top_k, - self.renormalize, - self.topk_group, - self.num_expert_group, + topk_weights, topk_ids = select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + use_grouped_topk=self.use_grouped_topk, + renormalize=self.renormalize, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + correction_bias=self.correction_bias, + custom_routing_function=self.custom_routing_function, ) reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess( @@ -250,16 +260,20 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): dtype=torch.float32, device=hidden_states.device, ) - silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( - gateup_output, - down_input, - gateup_output.shape[1], - reorder_topk_ids, - self.w2_input_scale, - self.start_expert_id, - self.end_expert_id, - BLOCK_SIZE=512, - ) + + if self.activation == "silu": + silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( + gateup_output, + down_input, + gateup_output.shape[1], + reorder_topk_ids, + self.w2_input_scale, + self.start_expert_id, + self.end_expert_id, + BLOCK_SIZE=512, + ) + else: + raise ValueError(f"Unsupported activation: {self.activation=}") # GroupGemm-1 down_output = torch.empty( @@ -297,35 +311,6 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): ) return output - def select_experts( - self, - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - ): - if self.use_grouped_topk: - assert topk_group is not None - assert num_expert_group is not None - topk_weights, topk_ids = grouped_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - ) - else: - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - ) - return topk_weights, topk_ids.to(torch.int32) - @classmethod def make_expert_params_mapping( cls, @@ -334,7 +319,6 @@ def make_expert_params_mapping( ckpt_up_proj_name: str, num_experts: int, ) -> List[Tuple[str, str, int, str]]: - return [ # (param_name, weight_name, expert_id, shard_id) ( @@ -379,7 +363,6 @@ def weight_loader( ) return - 
expert_data = param.data[expert_id] if shard_id == "w2": param.data[expert_id] = loaded_weight elif shard_id == "w1": @@ -644,6 +627,10 @@ def process_weights_after_loading(self, layer: Module) -> None: "QuantConfig has static quantization, but found " "activation scales are None." ) + layer.w13_weight_scale = torch.nn.Parameter( + torch.max(layer.w13_weight_scale, dim=1).values, + requires_grad=False, + ) return def apply( diff --git a/python/sglang/srt/layers/moe/fused_moe_native.py b/python/sglang/srt/layers/moe/fused_moe_native.py new file mode 100644 index 00000000000..042c0a52c56 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_native.py @@ -0,0 +1,129 @@ +""" +Torch-native implementation for FusedMoE. This is used for torch.compile. +It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204 +""" + +from typing import Callable, Optional + +import torch +from torch.nn import functional as F + +from sglang.srt.layers.activation import GeluAndMul, SiluAndMul +from sglang.srt.layers.moe.topk import select_experts + + +def fused_moe_forward_native( + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", +) -> torch.Tensor: + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + torch_native=True, + ) + + w13_weights = layer.w13_weight[topk_ids] + w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) + w2_weights = layer.w2_weight[topk_ids] + x1 = torch.einsum("ti,taoi -> tao", x, w1_weights) + if activation == "silu": + x1 = F.silu(x1) + elif activation == "gelu": + x1 = F.gelu(x1) + else: + raise ValueError(f"Unsupported activation: {activation=}") + x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) + expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) + return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) + + +def moe_forward_native( + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", +) -> torch.Tensor: + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + torch_native=True, + ) + + # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589 + len_experts = layer.num_experts + + cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts)) + cnts.scatter_(1, topk_ids.to(torch.int64), 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + + sorted_tokens 
= x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + if activation == "silu": + act = SiluAndMul() + elif activation == "gelu": + act = GeluAndMul() + else: + raise ValueError(f"Unsupported activation: {activation=}") + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + + layer_w13_weight = layer.w13_weight[i] + layer_w2_weight = layer.w2_weight[i] + + gate_up = F.linear(tokens_for_this_expert, layer_w13_weight) + gate_up = act(gate_up) + expert_out = F.linear(gate_up, layer_w2_weight) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + new_x = torch.empty_like(outs) + + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weights.dtype) + .mul_(topk_weights.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out diff --git a/python/sglang/srt/layers/fused_moe_triton/__init__.py b/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py similarity index 72% rename from python/sglang/srt/layers/fused_moe_triton/__init__.py rename to python/sglang/srt/layers/moe/fused_moe_triton/__init__.py index b895b9e4836..b68961931d5 100644 --- a/python/sglang/srt/layers/fused_moe_triton/__init__.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/__init__.py @@ -1,14 +1,12 @@ from contextlib import contextmanager from typing import Any, Dict, Optional -import sglang.srt.layers.fused_moe_triton.fused_moe # noqa -from sglang.srt.layers.fused_moe_triton.fused_moe import ( +import sglang.srt.layers.moe.fused_moe_triton.fused_moe # noqa +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( fused_experts, - fused_topk, get_config_file_name, - grouped_topk, ) -from sglang.srt.layers.fused_moe_triton.layer import ( +from sglang.srt.layers.moe.fused_moe_triton.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, @@ -37,8 +35,6 @@ def get_config() -> Optional[Dict[str, Any]]: "override_config", "get_config", "fused_moe", - "fused_topk", "fused_experts", "get_config_file_name", - "grouped_topk", ] diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity 
index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% 
rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from 
python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% 
rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json 
b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..2e692a1583a --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..a7be90051f8 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + 
"num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..6fcf408755f --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + 
"num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json new file mode 100644 index 00000000000..283ffd8ff1d --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + 
"GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..e2f8164cc6b --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 
128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 87% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json index d4c9ddd1297..ef6a0479cbd 100644 --- a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json @@ -3,53 +3,53 @@ "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, - "num_warps": 4, + "GROUP_SIZE_M": 32, + "num_warps": 8, "num_stages": 3 }, "2": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 4 + "num_stages": 5 }, "4": { "BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 3 + "num_stages": 5 }, "8": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, "num_warps": 4, - "num_stages": 2 + "num_stages": 3 }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 16, + "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5 }, "24": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 2 + "num_stages": 3 }, "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_M": 32, + 
"BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, @@ -57,32 +57,32 @@ }, "48": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 2 }, "64": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 5 }, "96": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, - "num_stages": 3 + "num_stages": 2 }, "128": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3 @@ -115,7 +115,7 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4 }, @@ -131,7 +131,7 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 32, + "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4 }, @@ -139,7 +139,7 @@ "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4 } diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..67bf2b720fe --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + 
"num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json new file mode 100644 index 00000000000..da71451a78b --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 
128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..14c6f8c1a35 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..2b974a78d39 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 
+1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json new file mode 100644 index 00000000000..054873f38fc --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 
128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..869f7512771 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 00000000000..40d85e43e87 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + 
"GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..4ebd7a7816e --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json new file mode 100644 index 00000000000..90416662e49 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json rename to 
python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json new file mode 100644 index 00000000000..8a18afe7d6d --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json rename to 
python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..96437b5ba2c --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 90% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json rename to 
python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json index b8d3be2313f..6e60506948d 100644 --- a/python/sglang/srt/layers/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json @@ -1,11 +1,11 @@ { "1": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 4 + "num_warps": 4, + "num_stages": 3 }, "2": { "BLOCK_SIZE_M": 16, @@ -13,15 +13,15 @@ "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 4 + "num_stages": 3 }, "4": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 4 + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 }, "8": { "BLOCK_SIZE_M": 32, @@ -41,11 +41,11 @@ }, "24": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 64, - "num_warps": 4, - "num_stages": 3 + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 }, "32": { "BLOCK_SIZE_M": 16, @@ -60,28 +60,28 @@ "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 2 }, "64": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 3 + "num_warps": 8, + "num_stages": 2 }, "96": { "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2 }, "128": { "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, @@ -99,9 +99,9 @@ "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "GROUP_SIZE_M": 16, "num_warps": 4, - "num_stages": 3 + "num_stages": 4 }, "1024": { "BLOCK_SIZE_M": 128, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..824009a5969 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 
1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json new file mode 100644 index 00000000000..bf2aa0d5f4d --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..62fd21136d0 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json new file mode 100644 index 00000000000..da40db13280 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..9d0db1cdc20 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json new file mode 100644 index 00000000000..1c906494ace --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + 
"num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..430f50090f4 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json new file mode 100644 index 00000000000..c51565f67b4 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 
4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json diff --git 
a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..449a1338428 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + 
"BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json new file mode 100644 index 00000000000..d317994cb71 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json 
similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json similarity index 87% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json index abf258e5641..bb17743b609 100644 --- a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json @@ -55,35 +55,35 @@ "kpack": 2 }, "128": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 0, - "waves_per_eu": 1, + "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1 }, "256": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 0, - "waves_per_eu": 1, + "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1 }, "512": { "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, - "num_warps": 4, + "num_warps": 8, "num_stages": 0, - "waves_per_eu": 2, + "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 }, diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json 
b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..72c3f560be9 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json new file mode 100644 index 00000000000..dd07b3f6ee0 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + 
"BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from 
python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..a841518ca67 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + 
"4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json new file mode 100644 index 00000000000..13cc2cee1d2 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json rename to 
python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json similarity index 100% rename from python/sglang/srt/layers/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json new file mode 100644 index 00000000000..b50cfc13da0 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/fused_moe_triton/configs/README b/python/sglang/srt/layers/moe/fused_moe_triton/configs/README similarity index 84% rename from python/sglang/srt/layers/fused_moe_triton/configs/README rename to python/sglang/srt/layers/moe/fused_moe_triton/configs/README index 45d40cbfb1a..4aa527f2719 100644 --- a/python/sglang/srt/layers/fused_moe_triton/configs/README +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/README @@ -8,3 +8,5 @@ the JSON file contains a mapping from M (batch size) to the chosen configuration The example configurations provided are for the Mixtral model for TP2 on H100 and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have N = 7168 and for TP4 we have N = 3584. + +See `benchmark/kernels/fused_moe_triton/README.md` on how to generate these config files. diff --git a/python/sglang/srt/layers/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py similarity index 63% rename from python/sglang/srt/layers/fused_moe_triton/fused_moe.py rename to python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 4f92512b2d5..32c8fcbb625 100644 --- a/python/sglang/srt/layers/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -6,16 +6,34 @@ import json import logging import os -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import torch import triton import triton.language as tl from vllm import _custom_ops as ops -from sglang.srt.utils import direct_register_custom_op, get_device_name +from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 +from sglang.srt.utils import ( + direct_register_custom_op, + get_device_name, + is_cuda_available, + is_hip, +) + +is_cuda = is_cuda_available() +is_hip_flag = is_hip() +if is_cuda: + from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size + logger = logging.getLogger(__name__) +padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 + +enable_moe_align_block_size_triton = bool( + int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0")) +) @triton.jit @@ -46,8 +64,14 @@ def fused_moe_kernel( stride_bn, stride_cm, stride_cn, + stride_asm, + stride_ask, stride_bse, + stride_bsk, stride_bsn, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -58,6 +82,7 @@ def fused_moe_kernel( compute_type: tl.constexpr, use_fp8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, + even_Ks: tl.constexpr, ): """ Implements the fused computation for a Mixture of Experts (MOE) using @@ -130,8 +155,15 @@ def fused_moe_kernel( b_scale = tl.load(b_scale_ptrs) if use_fp8_w8a8: - a_scale = tl.load(a_scale_ptr) - b_scale = tl.load(b_scale_ptr + off_experts) + if group_k > 0 and group_n > 0: + a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm + offs_bsn = offs_bn // group_n + b_scale_ptrs = ( + b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn + ) + else: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) # ----------------------------------------------------------- # Iterate to compute a block of the C matrix. 
@@ -143,17 +175,36 @@ def fused_moe_kernel( for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): # Load the next block of A and B, generate a mask by checking the # K dimension. - a = tl.load( - a_ptrs, - mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), - other=0.0, - ) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + if even_Ks: + a = tl.load( + a_ptrs, + mask=token_mask[:, None], + other=0.0, + ) + b = tl.load(b_ptrs) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + # We accumulate along the K dimension. if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) elif use_fp8_w8a8: - accumulator = tl.dot(a, b, acc=accumulator) + if group_k > 0 and group_n > 0: + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_scale = tl.load( + a_scale_ptrs + offs_ks * stride_ask, mask=token_mask, other=0.0 + ) + b_scale = tl.load(b_scale_ptrs + offs_ks * stride_bsk) + + accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :] + else: + accumulator = tl.dot(a, b, acc=accumulator) else: accumulator += tl.dot(a, b) # Advance the ptrs to the next K block. @@ -166,7 +217,10 @@ def fused_moe_kernel( if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) elif use_fp8_w8a8: - accumulator = (accumulator * a_scale * b_scale).to(compute_type) + if group_k > 0 and group_n > 0: + accumulator = accumulator.to(compute_type) + else: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) else: accumulator = accumulator.to(compute_type) # ----------------------------------------------------------- @@ -177,6 +231,139 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = tl.load(cumsum_ptr + pid) + 
end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -227,9 +414,45 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - ops.moe_align_block_size( - topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad - ) + if num_experts >= 224: + if enable_moe_align_block_size_triton or is_hip_flag: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + token_cnts_buffer = torch.empty( + (num_experts + 1) * num_experts, + dtype=torch.int32, + device=topk_ids.device, + ) + cumsum_buffer = torch.empty( + num_experts + 1, dtype=torch.int32, device=topk_ids.device + ) + + sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + token_cnts_buffer, + cumsum_buffer, + ) + else: + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) return sorted_ids, expert_ids, num_tokens_post_pad @@ -250,13 +473,24 @@ def invoke_fused_moe_kernel( compute_type: tl.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + block_shape: Optional[List[int]] = None, ) -> None: assert topk_weights.stride(1) == 1 assert sorted_token_ids.stride(0) == 1 + padded_size = 0 if use_fp8_w8a8: - A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None + if block_shape is None: + padded_size = padding_size + A, A_scale = ops.scaled_fp8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_fp8(A, block_k) + assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) 
== B_scale.shape[-1] elif use_int8_w8a16: assert B_scale is not None else: @@ -268,6 +502,12 @@ def invoke_fused_moe_kernel( * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) + K = B.shape[2] - padded_size + if K % config["BLOCK_SIZE_K"] == 0: + even_Ks = True + else: + even_Ks = False + fused_moe_kernel[grid]( A, B, @@ -279,7 +519,7 @@ def invoke_fused_moe_kernel( expert_ids, num_tokens_post_padded, B.shape[1], - B.shape[2], + B.shape[2] - padded_size, sorted_token_ids.shape[0], topk_ids.numel(), A.stride(0), @@ -289,25 +529,42 @@ def invoke_fused_moe_kernel( B.stride(1), C.stride(1), C.stride(2), - B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, - B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, + A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0, + A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0, + B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0, + B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0, + B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0, + 0 if block_shape is None else block_shape[0], + 0 if block_shape is None else block_shape[1], MUL_ROUTED_WEIGHT=mul_routed_weight, top_k=top_k, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + even_Ks=even_Ks, **config, ) -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: +def get_config_file_name( + E: int, N: int, dtype: Optional[str], block_shape: Optional[int] = None +) -> str: device_name = get_device_name().replace(" ", "_") dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + block_shape_selector = ( + "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}" + ) + return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" @functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs( + E: int, + N: int, + dtype: Optional[str], + block_n: Optional[int] = 0, + block_k: Optional[int] = 0, +) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. 
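As a worked example of what the hunks around here implement (the configs/README mapping from M to a chosen configuration, the block_shape-aware name from get_config_file_name, and the nearest-M lookup in try_get_optimal_moe_config) — illustrative only, not part of the patched files; the device name is an assumption and the two entries are copied from the E=8,N=7168 H200 fp8 config above:

def config_file_name(E, N, dtype=None, block_shape=None, device_name="NVIDIA_H200"):
    # mirrors get_config_file_name: dtype and block_shape selectors are optional
    dtype_sel = "" if not dtype else f",dtype={dtype}"
    block_sel = "" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
    return f"E={E},N={N},device_name={device_name}{dtype_sel}{block_sel}.json"

print(config_file_name(8, 7168, dtype="fp8_w8a8"))
# E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json

# Each config file maps M (batch size) to a Triton config; the closest key wins.
configs = {
    256: {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128,
          "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
    512: {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128,
          "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
}
M = 300
config = configs[min(configs, key=lambda k: abs(k - M))]  # -> the 256 entry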
@@ -319,7 +576,7 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N, dtype) + json_file_name = get_config_file_name(E, N, dtype, [block_n, block_k]) config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name @@ -350,21 +607,52 @@ def get_default_config( topk: int, dtype: Optional[str], is_marlin: bool, + block_shape: Optional[List[int]] = None, ) -> Dict[str, int]: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - } - # A heuristic: fused marlin works faster with this config for small M - if M <= E or (is_marlin and M <= 32): + if dtype == "fp8_w8a8": + if block_shape is None: + config = { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 2 if is_hip_flag else 4, + } + if M <= E: + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 if is_hip_flag else 4, + } + else: + # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_shape[0], + "BLOCK_SIZE_K": block_shape[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 if is_hip_flag else 3, + } + else: config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + } return config @@ -375,8 +663,9 @@ def try_get_optimal_moe_config( dtype: Optional[str], M: int, is_marlin: bool = False, + block_shape: Optional[List[int]] = None, ): - from sglang.srt.layers.fused_moe_triton import get_config + from sglang.srt.layers.moe.fused_moe_triton import get_config override_config = get_config() if override_config: @@ -384,7 +673,9 @@ def try_get_optimal_moe_config( else: # First try to load optimal config from the file E, _, N = w2_shape - configs = get_moe_configs(E, N, dtype) + block_n = block_shape[0] if block_shape else 0 + block_k = block_shape[1] if block_shape else 0 + configs = get_moe_configs(E, N, dtype, block_n, block_k) if configs: # If an optimal configuration map has been found, look up the @@ -392,78 +683,12 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin) + config = get_default_config( + M, E, N, w1_shape[2], top_k, dtype, is_marlin, block_shape + ) return config -def fused_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, -): - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - - M, _ = hidden_states.shape - - topk_weights = torch.empty( - M, topk, dtype=torch.float32, device=hidden_states.device - ) - topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = torch.empty( - M, topk, dtype=torch.int32, device=hidden_states.device - ) - - ops.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - 
gating_output.float(), # TODO(woosuk): Optimize this. - ) - del token_expert_indicies # Not used. Will be used in the future. - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - - return topk_weights, topk_ids - - -# This is used by the Deepseek-V2 model -def grouped_topk( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, - num_expert_group: int = 0, - topk_group: int = 0, -): - - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" - - scores = torch.softmax(gating_output, dim=-1) - num_token = scores.shape[0] - group_scores = ( - scores.view(num_token, num_expert_group, -1).max(dim=-1).values - ) # [n, n_group] - group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ - 1 - ] # [n, top_k_group] - group_mask = torch.zeros_like(group_scores) # [n, n_group] - group_mask.scatter_(1, group_idx, 1) # [n, n_group] - score_mask = ( - group_mask.unsqueeze(-1) - .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) - .reshape(num_token, -1) - ) # [n, e] - tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] - topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) - - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - - return topk_weights.to(torch.float32), topk_ids.to(torch.int32) - - def get_config_dtype_str( dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, @@ -486,12 +711,14 @@ def inplace_fused_experts( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> None: fused_experts_impl( hidden_states, @@ -500,12 +727,14 @@ def inplace_fused_experts( topk_weights, topk_ids, True, + activation, use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale, a1_scale, a2_scale, + block_shape, ) @@ -515,12 +744,14 @@ def inplace_fused_experts_fake( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> None: pass @@ -539,12 +770,14 @@ def outplace_fused_experts( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: return fused_experts_impl( hidden_states, @@ -553,12 +786,14 @@ def outplace_fused_experts( topk_weights, topk_ids, False, + activation, use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale, a1_scale, a2_scale, + block_shape, ) @@ -568,12 +803,14 @@ def outplace_fused_experts_fake( w2: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: 
Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -593,12 +830,14 @@ def fused_experts( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): if inplace: torch.ops.sglang.inplace_fused_experts( @@ -607,12 +846,14 @@ def fused_experts( w2, topk_weights, topk_ids, + activation, use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale, a1_scale, a2_scale, + block_shape, ) return hidden_states else: @@ -622,12 +863,14 @@ def fused_experts( w2, topk_weights, topk_ids, + activation, use_fp8_w8a8, use_int8_w8a16, w1_scale, w2_scale, a1_scale, a2_scale, + block_shape, ) @@ -638,15 +881,21 @@ def fused_experts_impl( topk_weights: torch.Tensor, topk_ids: torch.Tensor, inplace: bool = False, + activation: str = "silu", use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ): + padded_size = padding_size + if not use_fp8_w8a8 or block_shape is not None: + padded_size = 0 + # Check constraints. - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous" @@ -668,9 +917,10 @@ def fused_experts_impl( get_config_func = functools.partial( try_get_optimal_moe_config, w1.shape, - w2.shape, + (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size), topk_ids.shape[1], config_dtype, + block_shape=block_shape, ) config = get_config_func(M) @@ -743,9 +993,15 @@ def fused_experts_impl( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + block_shape=block_shape, ) - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + if activation == "silu": + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + elif activation == "gelu": + ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + else: + raise ValueError(f"Unsupported activation: {activation=}") invoke_fused_moe_kernel( intermediate_cache2, @@ -764,13 +1020,32 @@ def fused_experts_impl( compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, + block_shape=block_shape, ) - torch.sum( - intermediate_cache3.view(*intermediate_cache3.shape), - dim=1, - out=out_hidden_states[begin_chunk_idx:end_chunk_idx], - ) + if is_hip_flag: + ops.moe_sum( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], + ) + else: + if topk_ids.shape[1] == 1: + out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_( + intermediate_cache3[:, 0] + ) + elif topk_ids.shape[1] == 2: + torch.add( + intermediate_cache3[:, 0], + intermediate_cache3[:, 1], + out=out_hidden_states[begin_chunk_idx:end_chunk_idx], + ).squeeze(dim=1) + elif topk_ids.shape[1] > 2: + torch.sum( + 
intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=out_hidden_states[begin_chunk_idx:end_chunk_idx], + ) + return out_hidden_states @@ -782,6 +1057,7 @@ def fused_moe( topk: int, renormalize: bool, inplace: bool = False, + activation: str = "silu", use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, @@ -792,6 +1068,7 @@ def fused_moe( w2_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, + block_shape: Optional[List[int]] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -819,6 +1096,12 @@ def fused_moe( w1. - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2. + - a1_scale (Optional[torch.Tensor]): Optional scale to be used for + a1. + - a2_scale (Optional[torch.Tensor]): Optional scale to be used for + a2. + - block_shape: (Optional[List[int]]): Optional block size for block-wise + quantization. Returns: - torch.Tensor: The output tensor after applying the MoE layer. @@ -826,24 +1109,16 @@ def fused_moe( # Check constraints. assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - if use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - topk_weights, topk_ids = grouped_topk( - hidden_states, - gating_output, - topk, - renormalize, - num_expert_group, - topk_group, - ) - elif custom_routing_function is None: - topk_weights, topk_ids = fused_topk( - hidden_states, gating_output, topk, renormalize - ) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states, gating_output, topk, renormalize - ) + topk_weights, topk_ids = select_experts( + hidden_states=hidden_states, + router_logits=gating_output, + use_grouped_topk=use_grouped_topk, + top_k=topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + ) return fused_experts( hidden_states, @@ -852,10 +1127,12 @@ def fused_moe( topk_weights, topk_ids, inplace=inplace, + activation=activation, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=use_int8_w8a16, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, + block_shape=block_shape, ) diff --git a/python/sglang/srt/layers/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py similarity index 85% rename from python/sglang/srt/layers/fused_moe_triton/layer.py rename to python/sglang/srt/layers/moe/fused_moe_triton/layer.py index d9503fe2025..b71a878a0ba 100644 --- a/python/sglang/srt/layers/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -5,27 +5,31 @@ from typing import Callable, List, Optional, Tuple import torch -from vllm.distributed import ( +from vllm.model_executor.custom_op import CustomOp + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.custom_op import CustomOp - from sglang.srt.layers.custom_op_util import register_custom_op +from sglang.srt.layers.moe.fused_moe_native import moe_forward_native +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.utils import set_weight_attrs +from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs -if torch.cuda.is_available() 
or torch.hip.is_available(): - from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts +if torch.cuda.is_available(): + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts else: fused_experts = None # type: ignore import logging +is_hip_ = is_hip() + logger = logging.getLogger(__name__) @@ -33,6 +37,7 @@ class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" CHANNEL = "channel" GROUP = "group" + BLOCK = "block" class FusedMoEMethodBase(QuantizeMethodBase): @@ -95,6 +100,20 @@ def create_weights( layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if is_hip_ and get_bool_env_var("CK_MOE"): + layer.w13_weight = torch.nn.Parameter( + permute_weight(layer.w13_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + permute_weight(layer.w2_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + return + def apply( self, layer: torch.nn.Module, @@ -106,6 +125,8 @@ def apply( topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: return self.forward( x=x, @@ -117,6 +138,8 @@ def apply( topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + activation=activation, ) def forward_cuda( @@ -130,8 +153,10 @@ def forward_cuda( topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: - topk_weights, topk_ids = FusedMoE.select_experts( + topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -140,19 +165,58 @@ def forward_cuda( topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, + correction_bias=correction_bias, ) - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - ) + if is_hip_ and get_bool_env_var("CK_MOE"): + import ater + from ater.fused_moe import fused_experts_ck + + assert activation == "silu", f"{activation=} is not supported." 
+ + return fused_experts_ck( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + ) + else: + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + ) - def forward_cpu(self, *args, **kwargs): - raise NotImplementedError("The CPU backend currently does not support MoE.") + def forward_cpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return moe_forward_native( + layer, + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + custom_routing_function, + correction_bias, + ) def forward_tpu(self, *args, **kwargs) -> torch.Tensor: raise NotImplementedError("The TPU backend currently does not support MoE.") @@ -197,6 +261,9 @@ def __init__( tp_size: Optional[int] = None, prefix: str = "", custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + use_presharded_weights: bool = False, ): super().__init__() @@ -208,6 +275,7 @@ def __init__( ) self.top_k = top_k self.num_experts = num_experts + assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results self.renormalize = renormalize @@ -217,6 +285,8 @@ def __init__( self.num_expert_group = num_expert_group self.topk_group = topk_group self.custom_routing_function = custom_routing_function + self.correction_bias = correction_bias + self.activation = activation if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( @@ -234,6 +304,7 @@ def __init__( params_dtype=params_dtype, weight_loader=self.weight_loader, ) + self.use_presharded_weights = use_presharded_weights def _load_per_tensor_weight_scale( self, @@ -312,9 +383,12 @@ def _load_w13( # Index the loaded weight for tp sharding. # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim shard_size = expert_data.shape[shard_dim] // 2 - loaded_weight = loaded_weight.narrow( - shard_dim, shard_size * tp_rank, shard_size - ) + + if not self.use_presharded_weights: + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * tp_rank, shard_size + ) + # Narrow parameter and load. # w1, gate_proj: Load into first logical weight of w13. if shard_id == "w1": @@ -338,9 +412,12 @@ def _load_w2( # down_proj: "RowParallel" so tp sharding on input_dim # Narrow parameter and load. shard_size = expert_data.shape[shard_dim] - loaded_weight = loaded_weight.narrow( - shard_dim, shard_size * tp_rank, shard_size - ) + + if not self.use_presharded_weights: + loaded_weight = loaded_weight.narrow( + shard_dim, shard_size * tp_rank, shard_size + ) + # w2, down_proj: Load into only logical weight of w2. 
expert_data.copy_(loaded_weight) @@ -381,7 +458,6 @@ def weight_loader( shard_id: str, expert_id: int, ) -> None: - # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality @@ -463,7 +539,10 @@ def weight_loader( expert_data=expert_data, tp_rank=tp_rank, ) - elif quant_method == FusedMoeWeightScaleSupported.GROUP.value: + elif quant_method in [ + FusedMoeWeightScaleSupported.GROUP.value, + FusedMoeWeightScaleSupported.BLOCK.value, + ]: self._load_model_weight_or_group_weight_scale( shard_id=shard_id, shard_dim=shard_dim, @@ -503,51 +582,6 @@ def weight_loader( ) return - @staticmethod - def select_experts( - hidden_states: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - use_grouped_topk: bool, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - ): - from sglang.srt.layers.fused_moe_triton.fused_moe import ( - fused_topk, - grouped_topk, - ) - - # DeekSeekv2 uses grouped_top_k - if use_grouped_topk: - assert topk_group is not None - assert num_expert_group is not None - topk_weights, topk_ids = grouped_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - num_expert_group=num_expert_group, - topk_group=topk_group, - ) - elif custom_routing_function is None: - topk_weights, topk_ids = fused_topk( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - ) - else: - topk_weights, topk_ids = custom_routing_function( - hidden_states=hidden_states, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - ) - - return topk_weights, topk_ids - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None @@ -562,6 +596,8 @@ def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): topk_group=self.topk_group, num_expert_group=self.num_expert_group, custom_routing_function=self.custom_routing_function, + correction_bias=self.correction_bias, + activation=self.activation, ) if self.reduce_results and self.tp_size > 1: diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py new file mode 100644 index 00000000000..dc53e4445db --- /dev/null +++ b/python/sglang/srt/layers/moe/topk.py @@ -0,0 +1,211 @@ +# Copyright 2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from typing import Callable, Optional + +import torch +import torch.nn.functional as F + +from sglang.srt.utils import get_compiler_backend + + +def fused_topk_native( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + assert ( + hidden_states.shape[0] == gating_output.shape[0] + ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}" + M, _ = hidden_states.shape + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) + topk_weights = F.softmax(gating_output.float(), dim=-1) + topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids + + +def fused_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + from vllm import _custom_ops as ops + + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + M, _ = hidden_states.shape + + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) + token_expert_indicies = torch.empty( + M, topk, dtype=torch.int32, device=hidden_states.device + ) + + ops.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), + ) + del token_expert_indicies + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +# This is used by the Deepseek-V2 model +@torch.compile(dynamic=True, backend=get_compiler_backend()) +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +): + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +@torch.compile(dynamic=True, backend=get_compiler_backend()) +def biased_grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + correction_bias: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +): + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + scores = gating_output.sigmoid() + num_token = scores.shape[0] + scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0) + group_scores = ( + 
scores_for_choice.view(num_token, num_expert_group, -1) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e] + _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + topk_weights = scores.gather(1, topk_ids) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def select_experts( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + use_grouped_topk: bool, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + torch_native: bool = False, +): + # DeekSeekv2 uses grouped_top_k + if use_grouped_topk: + assert topk_group is not None + assert num_expert_group is not None + if correction_bias is None: + topk_weights, topk_ids = grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) + else: + topk_weights, topk_ids = biased_grouped_topk( + hidden_states=hidden_states, + gating_output=router_logits, + correction_bias=correction_bias, + topk=top_k, + renormalize=renormalize, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) + elif torch_native and custom_routing_function is None: + topk_weights, topk_ids = fused_topk_native( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + elif custom_routing_function is None: + topk_weights, topk_ids = fused_topk( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize, + ) + + return topk_weights, topk_ids diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py new file mode 100644 index 00000000000..78be6798254 --- /dev/null +++ b/python/sglang/srt/layers/parameter.py @@ -0,0 +1,449 @@ +"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/parameter.py""" + +import logging +from fractions import Fraction +from typing import Callable, Optional, Union + +import torch +from torch.nn import Parameter + +from sglang.srt.distributed import get_tensor_model_parallel_rank + +__all__ = [ + "BasevLLMParameter", + "PackedvLLMParameter", + "PerTensorScaleParameter", + "ModelWeightParameter", + "ChannelQuantScaleParameter", + "GroupQuantScaleParameter", + "PackedColumnParameter", + "RowvLLMParameter", +] + +logger = logging.getLogger(__name__) + + +class BasevLLMParameter(Parameter): + """ + Base parameter for vLLM linear layers. Extends the torch.nn.parameter + by taking in a linear weight loader. Will copy the loaded weight + into the parameter when the provided weight loader is called. 
+ """ + + def __new__(cls, data: torch.Tensor, **kwargs): + + return super().__new__(cls, data=data, requires_grad=False) + + def __init__(self, data: torch.Tensor, weight_loader: Callable): + """ + Initialize the BasevLLMParameter + + :param data: torch tensor with the parameter data + :param weight_loader: weight loader callable + + :returns: a torch.nn.parameter + """ + + self._weight_loader = weight_loader + + @property + def weight_loader(self): + return self._weight_loader + + def _assert_and_load(self, loaded_weight: torch.Tensor): + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor): + self._assert_and_load(loaded_weight) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + self._assert_and_load(loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + self._assert_and_load(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + self._assert_and_load(loaded_weight) + + +class _ColumnvLLMParameter(BasevLLMParameter): + """ + Private class defining weight loading functionality + (load_merged_column_weight, load_qkv_weight) + for parameters being loaded into linear layers with column + parallelism. This includes QKV and MLP layers which are + not already fused on disk. Requires an output dimension + to be defined. Called within the weight loader of + each of the column parallel linear layers. + """ + + def __init__(self, output_dim: int, **kwargs): + self._output_dim = output_dim + super().__init__(**kwargs) + + @property + def output_dim(self): + return self._output_dim + + def load_column_parallel_weight( + self, + loaded_weight: torch.Tensor, + tp_rank: int, + use_presharded_weights: bool = False, + ): + if not use_presharded_weights: + shard_size = self.data.shape[self.output_dim] + loaded_weight = loaded_weight.narrow( + self.output_dim, tp_rank * shard_size, shard_size + ) + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + + shard_offset = kwargs.get("shard_offset") + shard_size = kwargs.get("shard_size") + use_presharded_weights = kwargs.get("use_presharded_weights") + if ( + isinstance(self, (PackedColumnParameter, PackedvLLMParameter)) + and self.packed_dim == self.output_dim + ): + shard_size, shard_offset = self.adjust_shard_indexes_for_packing( + shard_offset=shard_offset, shard_size=shard_size + ) + + param_data = self.data + + tp_rank = get_tensor_model_parallel_rank() + param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) + if not use_presharded_weights: + loaded_weight = loaded_weight.narrow( + self.output_dim, tp_rank * shard_size, shard_size + ) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def load_qkv_weight( + self, + loaded_weight: torch.Tensor, + tp_rank: int, + use_presharded_weights: bool = False, + **kwargs, + ): + + shard_offset = kwargs.get("shard_offset") + shard_size = kwargs.get("shard_size") + shard_id = kwargs.get("shard_id") + num_heads = kwargs.get("num_heads") + + if ( + isinstance(self, (PackedColumnParameter, PackedvLLMParameter)) + and self.output_dim == self.packed_dim + ): + shard_size, shard_offset = self.adjust_shard_indexes_for_packing( + shard_offset=shard_offset, shard_size=shard_size + ) + + param_data = self.data + shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads + 
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) + if not use_presharded_weights: + loaded_weight = loaded_weight.narrow( + self.output_dim, shard_id * shard_size, shard_size + ) + + assert ( + param_data.shape == loaded_weight.shape + ), f"{param_data.shape=}, {loaded_weight.shape=}" + param_data.copy_(loaded_weight) + + +class RowvLLMParameter(BasevLLMParameter): + """ + Parameter class defining weight_loading functionality + (load_row_parallel_weight) for parameters being loaded + into linear layers with row parallel functionality. + Requires an input_dim to be defined. + """ + + def __init__(self, input_dim: int, **kwargs): + self._input_dim = input_dim + super().__init__(**kwargs) + + @property + def input_dim(self): + return self._input_dim + + def load_row_parallel_weight( + self, + loaded_weight: torch.Tensor, + tp_rank: int, + use_presharded_weights: bool = False, + ): + if not use_presharded_weights: + shard_size = self.data.shape[self.input_dim] + loaded_weight = loaded_weight.narrow( + self.input_dim, tp_rank * shard_size, shard_size + ) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + +class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for linear layer weights. Uses both column and + row parallelism. + """ + + pass + + +class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + grouped quantization. Uses both column and row parallelism. + """ + + pass + + +class ChannelQuantScaleParameter(_ColumnvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + channel-wise quantization. Equivalent to _ColumnvLLMParameter. + """ + + pass + + +class PerTensorScaleParameter(BasevLLMParameter): + """ + Parameter class for scales where the number of scales is + equivalent to the number of logical matrices in fused linear + layers (e.g. for QKV, there are 3 scales loaded from disk). + This is relevant to weights with per-tensor quantization. + Adds functionality to map the scalers to a shard during + weight loading. 
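+ (For a fused QKV layer, for instance, a loaded "q", "k" or "v" scale is
+ written into slot 0, 1 or 2 of this parameter.)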
+ + Note: additional parameter manipulation may be handled + for each quantization config specifically, within + process_weights_after_loading + """ + + def __init__(self, **kwargs): + self.qkv_idxs = {"q": 0, "k": 1, "v": 2} + super().__init__(**kwargs) + + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + # if not int, assume shard_id for qkv + # map to int and return + assert isinstance(shard_id, str) + assert shard_id in self.qkv_idxs + return self.qkv_idxs[shard_id] + + # For row parallel layers, no sharding needed + # load weight into parameter as is + def load_row_parallel_weight(self, *args, **kwargs): + kwargs.pop("tp_rank", None) + kwargs.pop("use_presharded_weights", None) + super().load_row_parallel_weight(*args, **kwargs) + + def load_merged_column_weight(self, *args, **kwargs): + self._load_into_shard_id(*args, **kwargs) + + def load_qkv_weight(self, *args, **kwargs): + self._load_into_shard_id(*args, **kwargs) + + def load_column_parallel_weight(self, *args, **kwargs): + kwargs.pop("tp_rank", None) + kwargs.pop("use_presharded_weights", None) + super().load_row_parallel_weight(*args, **kwargs) + + def _load_into_shard_id( + self, loaded_weight: torch.Tensor, shard_id: Union[str, int], **kwargs + ): + """ + Slice the parameter data based on the shard id for + loading. + """ + + param_data = self.data + shard_id = self._shard_id_as_int(shard_id) + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + param_data = param_data[shard_id] + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class PackedColumnParameter(_ColumnvLLMParameter): + """ + Parameter for model parameters which are packed on disk + and support column parallelism only. See PackedvLLMParameter + for more details on the packed properties. + """ + + def __init__( + self, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + **kwargs, + ): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + self._marlin_tile_size = marlin_tile_size + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + @property + def marlin_tile_size(self): + return self._marlin_tile_size + + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + marlin_tile_size=self.marlin_tile_size, + ) + + +class PackedvLLMParameter(ModelWeightParameter): + """ + Parameter for model weights which are packed on disk. + Example: GPTQ Marlin weights are int4 or int8, packed into int32. + Extends the ModelWeightParameter to take in the + packed factor, the packed dimension, and optionally, marlin + tile size for marlin kernels. Adjusts the shard_size and + shard_offset for fused linear layers model weight loading + by accounting for packing and optionally, marlin tile size. 
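+ For example, int4 weights packed into int32 have a packed factor of 8, so
+ shard sizes and offsets along the packed dimension are divided by 8 before
+ narrowing.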
+ """ + + def __init__( + self, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + **kwargs, + ): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + self._marlin_tile_size = marlin_tile_size + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + @property + def marlin_tile_size(self): + return self._marlin_tile_size + + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + marlin_tile_size=self.marlin_tile_size, + ) + + +def permute_param_layout_( + param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs +) -> BasevLLMParameter: + """ + Permute a parameter's layout to the specified input and output dimensions, + useful for forcing the parameter into a known layout, for example, if I need + a packed (quantized) weight matrix to be in the layout + {input_dim = 0, output_dim = 1, packed_dim = 0} + then I can call: + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + to ensure x is in the correct layout (permuting it to the correct layout if + required, asserting if it cannot get it to the correct layout) + """ + + curr_input_dim = getattr(param, "input_dim", None) + curr_output_dim = getattr(param, "output_dim", None) + + if curr_input_dim is None or curr_output_dim is None: + assert param.data.dim() == 2, ( + "permute_param_layout_ only supports 2D parameters when either " + "input_dim or output_dim is not set" + ) + + # if one of the dimensions is not set, set it to the opposite of the other + # we can only do this since we asserted the parameter is 2D above + if curr_input_dim is None: + assert curr_output_dim is not None, "either input or output dim must be set" + curr_input_dim = (curr_output_dim + 1) % 2 + if curr_output_dim is None: + assert curr_input_dim is not None, "either input or output dim must be set" + curr_output_dim = (curr_input_dim + 1) % 2 + + # create permutation from the current layout to the layout with + # self.input_dim at input_dim and self.output_dim at output_dim preserving + # other dimensions + perm = [ + i for i in range(param.data.dim()) if i not in [curr_input_dim, curr_output_dim] + ] + perm.insert(input_dim, curr_input_dim) + perm.insert(output_dim, curr_output_dim) + + if "packed_dim" in kwargs: + assert ( + hasattr(param, "packed_dim") + and param.packed_dim == perm[kwargs["packed_dim"]] + ), "permute_param_layout_ currently doesn't support repacking" + + param.data = param.data.permute(*perm) + if hasattr(param, "_input_dim"): + param._input_dim = input_dim + if hasattr(param, "_output_dim"): + param._output_dim = output_dim + if "packed_dim" in kwargs and hasattr(param, "_packed_dim"): + param._packed_dim = kwargs["packed_dim"] + + return param + + +def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size): + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + +def _adjust_shard_indexes_for_packing( + shard_size, shard_offset, packed_factor, marlin_tile_size +): + shard_size = shard_size // packed_factor + shard_offset = shard_offset // packed_factor + if marlin_tile_size is not None: + return _adjust_shard_indexes_for_marlin( + shard_size=shard_size, + shard_offset=shard_offset, + marlin_tile_size=marlin_tile_size, + ) + return shard_size, 
shard_offset diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index 3e2078c4a4d..1c0092c1a40 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -1,8 +1,7 @@ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py -from typing import Callable, Dict, Optional, Type +from typing import Dict, Type -import torch from vllm.model_executor.layers.quantization.aqlm import AQLMConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig @@ -22,7 +21,9 @@ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod +from sglang.srt.layers.quantization.fp8 import Fp8Config +from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config +from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, @@ -32,6 +33,7 @@ "fp8": Fp8Config, "fbgemm_fp8": FBGEMMFp8Config, "marlin": MarlinConfig, + "modelopt": ModelOptFp8Config, "gguf": GGUFConfig, "gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin": GPTQMarlinConfig, @@ -41,6 +43,7 @@ "bitsandbytes": BitsAndBytesConfig, "qqq": QQQConfig, "experts_int8": ExpertsInt8Config, + "w8a8_int8": W8A8Int8Config, } @@ -53,78 +56,14 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: return QUANTIZATION_METHODS[quantization] -def fp8_moe_apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, -) -> torch.Tensor: - """Enhanced apply method for FP8 MoE.""" - from sglang.srt.layers.fused_moe_triton import FusedMoE - from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts - - # Expert selection - topk_weights, topk_ids = FusedMoE.select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - ) - - # Expert fusion with FP8 quantization - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=True, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) - - -def fp8_get_quant_method(self, layer, prefix): - """Enhanced get_quant_method for FP8 config.""" - from vllm.model_executor.layers.linear import LinearBase - from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped, - ) - - from sglang.srt.layers.fused_moe_triton.layer import FusedMoE - from sglang.srt.layers.linear import UnquantizedLinearMethod - from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod - - if isinstance(layer, LinearBase): - if is_layer_skipped(prefix, self.ignored_layers): - return UnquantizedLinearMethod() - return Fp8LinearMethod(self) - elif isinstance(layer, FusedMoE): - return 
Fp8MoEMethod(self) - return None - - def gptq_get_quant_method(self, layer, prefix): - from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinLinearMethod, GPTQMarlinMoEMethod, ) - from sglang.srt.layers.fused_moe_triton.layer import FusedMoE + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE if isinstance(layer, LinearBase): return GPTQMarlinLinearMethod(self) @@ -134,13 +73,13 @@ def gptq_get_quant_method(self, layer, prefix): def awq_get_quant_method(self, layer, prefix): - from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.awq_marlin import ( AWQMarlinLinearMethod, AWQMoEMethod, ) - from sglang.srt.layers.fused_moe_triton.layer import FusedMoE + from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE if isinstance(layer, LinearBase): return AWQMarlinLinearMethod(self) @@ -149,14 +88,30 @@ def awq_get_quant_method(self, layer, prefix): return None +def patch_vllm_linear_base_isinstance(): + import builtins + + from vllm.model_executor.layers.linear import LinearBase + + from sglang.srt.layers.linear import LinearBase as PatchedLinearBase + + original_isinstance = builtins.isinstance + + def patched_isinstance(obj, classinfo): + if classinfo is LinearBase: + return original_isinstance(obj, PatchedLinearBase) + return original_isinstance(obj, classinfo) + + builtins.isinstance = patched_isinstance + + def apply_monkey_patches(): """Apply all monkey patches in one place.""" - setattr(Fp8MoEMethod, "apply", fp8_moe_apply) - setattr(Fp8Config, "get_quant_method", fp8_get_quant_method) setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method) setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method) +patch_vllm_linear_base_isinstance() # Apply patches when module is imported apply_monkey_patches() diff --git a/python/sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..6496a38fba8 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..c098ef2dbb9 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + 
"num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..3618053b658 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + 
"512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..46a982f5ee9 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, 
+ "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..035ec027fa5 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json 
b/python/sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..8b49f2781cb --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..851bc9f9f0b --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..d1227c21579 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + 
"24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..1c61451fb34 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..63e661c80de --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..cf354037903 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 
4, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..6f5adbb9361 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + 
}, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..eccb86a76df --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..88af48431d8 --- /dev/null +++ 
b/python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..dd069726d7e --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + 
"num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..56b939e52fa --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..63d9a0bf5d7 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 
3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..7fa398c15a2 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..f15d8f64c70 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..4225c78eb72 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git 
a/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..cd3e07804fd --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..9d5a329d746 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, 
+ "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..5e6789d00e0 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..96e1594a3ea --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, 
+ "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..5ffd367df83 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + 
"num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..49ac14d2a57 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 
2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..eabc423949a --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..dcbb0efc53e --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + 
"2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..51e237b91b8 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git 
a/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..6280219c9ee --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..40c01c0b92b --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + 
"BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..c6fd3659799 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { 
+ "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..160f12ed3f9 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 
64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..e5c4a1d2c94 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..2bf5eb27e38 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + 
}, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..0a1e14cffbb --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git 
a/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..15b1c93f60f --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..8ff12e64c17 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..4532f93681e --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + 
"num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..dfe5c1e43d6 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..ca7f32b9552 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..5acea242cc0 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 5 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..a87f5de1b18 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, 
+ "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..58cdd93e90b --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..b72e0371d14 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 
00000000000..468f9e78da0 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 00000000000..293adce387e --- /dev/null +++ 
b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 5 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index acdce0b8cbd..f5a0005a282 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -4,10 +4,10 @@ from typing import Any, Callable, Dict, List, Optional import torch +import torch.nn.functional as F from torch.nn import Module from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, @@ -22,28 +22,35 @@ per_tensor_dequantize, requantize_with_max_scale, ) -from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter -from 
sglang.srt.layers.fused_moe_triton import ( - FusedMoE, - FusedMoEMethodBase, - FusedMoeWeightScaleSupported, +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, ) -from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod +from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz +from sglang.srt.layers.quantization.fp8_utils import ( + BlockQuantScaleParameter, + apply_w8a8_block_fp8_linear, + normalize_e4m3fn_to_e4m3fnuz, +) from sglang.srt.utils import ( get_bool_env_var, is_hip, + permute_weight, print_warning_once, set_weight_attrs, ) ACTIVATION_SCHEMES = ["static", "dynamic"] +is_hip_ = is_hip() + logger = logging.getLogger(__name__) @@ -55,6 +62,7 @@ def __init__( is_checkpoint_fp8_serialized: bool = False, activation_scheme: str = "dynamic", ignored_layers: Optional[List[str]] = None, + weight_block_size: List[int] = None, ) -> None: self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if is_checkpoint_fp8_serialized: @@ -66,6 +74,20 @@ def __init__( raise ValueError(f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme self.ignored_layers = ignored_layers or [] + if weight_block_size is not None: + if not is_checkpoint_fp8_serialized: + raise ValueError( + f"The block-wise quantization only supports fp8-serialized checkpoint for now." + ) + if len(weight_block_size) != 2: + raise ValueError( + f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions." + ) + if activation_scheme != "dynamic": + raise ValueError( + f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme." 
+ ) + self.weight_block_size = weight_block_size @classmethod def get_name(cls) -> str: @@ -89,10 +111,12 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": is_checkpoint_fp8_serialized = "fp8" in quant_method activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) return cls( is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, activation_scheme=activation_scheme, ignored_layers=ignored_layers, + weight_block_size=weight_block_size, ) def get_quant_method( @@ -100,6 +124,8 @@ def get_quant_method( ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + if isinstance(layer, LinearBase): if is_layer_skipped(prefix, self.ignored_layers): return UnquantizedLinearMethod() @@ -140,7 +166,12 @@ def __init__(self, quant_config: Fp8Config): # kernel for fast weight-only FP8 quantization self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") # Disable marlin for ROCm - if is_hip(): + if is_hip_: + self.use_marlin = False + + self.block_quant = self.quant_config.weight_block_size is not None + if self.block_quant: + # Marlin doesn't support block-wise fp8 self.use_marlin = False def create_weights( @@ -153,10 +184,35 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - del input_size, output_size output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") + tp_size = get_tensor_model_parallel_world_size() + if self.block_quant: + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # Required by row parallel + if tp_size > 1 and input_size // input_size_per_partition == tp_size: + if input_size_per_partition % block_k != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}." + ) + # Required by collum parallel or enabling merged weights + if ( + tp_size > 1 and output_size // output_size_per_partition == tp_size + ) or len(output_partition_sizes) > 1: + for output_partition_size in output_partition_sizes: + if output_partition_size % block_n != 0: + raise ValueError( + f"Weight output_partition_size = " + f"{output_partition_size} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition @@ -184,13 +240,27 @@ def create_weights( # Otherwise, wait until process_weights_after_loading. 
if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE - scale = PerTensorScaleParameter( - data=torch.empty(len(output_partition_sizes), dtype=torch.float32), - weight_loader=weight_loader, - ) - - scale[:] = torch.finfo(torch.float32).min - layer.register_parameter("weight_scale", scale) + if self.block_quant: + assert self.quant_config.activation_scheme == "dynamic" + scale = BlockQuantScaleParameter( + data=torch.empty( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale_inv", scale) + else: + scale = PerTensorScaleParameter( + data=torch.empty(len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", scale) # INPUT ACTIVATION SCALE if self.quant_config.activation_scheme == "static": @@ -205,6 +275,29 @@ def create_weights( layer.register_parameter("input_scale", None) def process_weights_after_loading(self, layer: Module) -> None: + # Block quant doesn't need to process weights after loading + if self.block_quant: + # If ROCm, normalize the weights and scales to e4m3fnuz + if is_hip_: + # activation_scheme: dynamic + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=layer.weight, + weight_scale=layer.weight_scale_inv, + input_scale=None, + ) + layer.weight = torch.nn.Parameter(weight, requires_grad=False) + layer.weight_scale_inv = torch.nn.Parameter( + weight_scale, requires_grad=False + ) + layer.input_scale = None + else: + layer.weight = torch.nn.Parameter( + layer.weight.data, requires_grad=False + ) + layer.weight_scale_inv = torch.nn.Parameter( + layer.weight_scale_inv.data, requires_grad=False + ) + return layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) # If checkpoint not serialized fp8, quantize the weights. if not self.quant_config.is_checkpoint_fp8_serialized: @@ -249,7 +342,7 @@ def process_weights_after_loading(self, layer: Module) -> None: weight_scale = layer.weight_scale # If ROCm, normalize the weights and scales to e4m3fnuz - if is_hip(): + if is_hip_: weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=weight, weight_scale=weight_scale, @@ -295,6 +388,16 @@ def apply( bias=bias, ) + if self.block_quant: + return apply_w8a8_block_fp8_linear( + input=x, + weight=layer.weight, + block_size=self.quant_config.weight_block_size, + weight_scale=layer.weight_scale_inv, + input_scale=None, + bias=bias, + ) + return apply_fp8_linear( input=x, weight=layer.weight, @@ -306,7 +409,7 @@ def apply( ) -class Fp8MoEMethod(FusedMoEMethodBase): +class Fp8MoEMethod: """MoE method for FP8. Supports loading FP8 checkpoints with static weight scale and dynamic/static activation scale. @@ -319,8 +422,27 @@ class Fp8MoEMethod(FusedMoEMethodBase): quant_config: The quantization config. 
""" - def __init__(self, quant_config: Fp8Config): + def __new__(cls, *args, **kwargs): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + def __init__(self, quant_config): self.quant_config = quant_config + self.block_quant = self.quant_config.weight_block_size is not None def create_weights( self, @@ -331,9 +453,32 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported if self.quant_config.is_checkpoint_fp8_serialized: params_dtype = torch.float8_e4m3fn + tp_size = get_tensor_model_parallel_world_size() + if self.block_quant: + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. + # Required by collum parallel or enabling merged weights + if intermediate_size % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1: + # Required by row parallel + if intermediate_size % block_k != 0: + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_k = {block_k}." + ) # WEIGHTS w13_weight = torch.nn.Parameter( @@ -355,21 +500,45 @@ def create_weights( set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES - # Allocate 2 scales for w1 and w3 respectively. - # They will be combined to a single scale after weight loading. - w13_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w13_weight_scale", w13_weight_scale) - - w2_weight_scale = torch.nn.Parameter( - torch.ones(num_experts, dtype=torch.float32), requires_grad=False - ) - layer.register_parameter("w2_weight_scale", w2_weight_scale) + if self.block_quant: + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + assert self.quant_config.activation_scheme == "dynamic" + else: + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. 
+ w13_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add the quantization method used (per tensor/grouped/channel) # to ensure the weight scales are loaded in properly extra_weight_attrs.update( - {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + if self.block_quant + else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} ) # If loading fp8 checkpoint, pass the weight loaders. # If loading an fp16 checkpoint, do not (we will quantize in @@ -403,11 +572,41 @@ def create_weights( layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + padding_size, # Avoid circular import + ) - # If checkpoint is fp16, quantize in place. + # Block quant doesn't need to process weights after loading + if self.block_quant: + # If ROCm, normalize the weights and scales to e4m3fnuz + if is_hip_: + # activation_scheme: dynamic + w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=layer.w13_weight, + weight_scale=layer.w13_weight_scale_inv, + input_scale=None, + ) + w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=layer.w2_weight, + weight_scale=layer.w2_weight_scale_inv, + input_scale=None, + ) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w13_weight_scale_inv = torch.nn.Parameter( + w13_weight_scale, requires_grad=False + ) + layer.w13_input_scale = None + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + layer.w2_weight_scale_inv = torch.nn.Parameter( + w2_weight_scale, requires_grad=False + ) + layer.w2_input_scale = None + return + # If checkpoint is fp16 or bfloat16, quantize in place. if not self.quant_config.is_checkpoint_fp8_serialized: # If ROCm, use float8_e4m3fnuz instead (MI300x HW) - fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn + fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) @@ -428,6 +627,31 @@ def process_weights_after_loading(self, layer: Module) -> None: ) layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + + if is_hip_: + if get_bool_env_var("CK_MOE"): + layer.w13_weight = torch.nn.Parameter( + permute_weight(layer.w13_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + permute_weight(layer.w2_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + elif get_bool_env_var("MOE_PADDING"): + # If ROCm, apply weight padding (min. 
Mem channel contention) only if set + layer.w13_weight = torch.nn.Parameter( + F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() return # If checkpoint is fp8, we need to handle that the @@ -456,8 +680,9 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale.max(), requires_grad=False ) + # If ROCm, normalize the weights and scales to e4m3fnuz - if is_hip(): + if is_hip_: # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( @@ -486,7 +711,6 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w2_input_scale = torch.nn.Parameter( w2_input_scale, requires_grad=False ) - # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max then dequant and requant each expert. assert layer.w13_weight_scale is not None @@ -507,6 +731,31 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_weight_scale = torch.nn.Parameter( max_w13_scales, requires_grad=False ) + + if is_hip_: + if get_bool_env_var("CK_MOE"): + layer.w13_weight = torch.nn.Parameter( + permute_weight(layer.w13_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + permute_weight(layer.w2_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + elif get_bool_env_var("MOE_PADDING"): + # If ROCm, apply weight padding (min. Mem channel contention) only if set + layer.w13_weight = torch.nn.Parameter( + F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() return def apply( @@ -520,11 +769,14 @@ def apply( topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", ) -> torch.Tensor: + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + from sglang.srt.layers.moe.topk import select_experts - from vllm.model_executor.layers.fused_moe import fused_experts - - topk_weights, topk_ids = FusedMoE.select_experts( + # Expert selection + topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -533,21 +785,61 @@ def apply( topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, + correction_bias=correction_bias, ) - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=True, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - ) + if is_hip_ and get_bool_env_var("CK_MOE"): + import ater + from ater.fused_moe import fused_experts_ck + + assert activation == "silu", f"{activation=} is not supported." 
+ + return fused_experts_ck( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + use_fp8_w8a8=True, + w1_scale=( + layer.w13_weight_scale_inv + if self.block_quant + else layer.w13_weight_scale + ), + w2_scale=( + layer.w2_weight_scale_inv + if self.block_quant + else layer.w2_weight_scale + ), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) + + else: + # Expert fusion with FP8 quantization + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_fp8_w8a8=True, + w1_scale=( + layer.w13_weight_scale_inv + if self.block_quant + else layer.w13_weight_scale + ), + w2_scale=( + layer.w2_weight_scale_inv + if self.block_quant + else layer.w2_weight_scale + ), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size, + ) class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py new file mode 100644 index 00000000000..fe57838e591 --- /dev/null +++ b/python/sglang/srt/layers/quantization/fp8_kernel.py @@ -0,0 +1,351 @@ +# Copyright 2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import functools +import json +import logging +import os +from typing import Any, Dict, List, Optional, Tuple + +import torch +import triton +import triton.language as tl + +from sglang.srt.utils import get_device_name, is_hip + +is_hip_ = is_hip() +fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn + +logger = logging.getLogger(__name__) + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + # Stride of input + y_stride, + # Collums of input + N, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group quantization on a + tensor. + + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * y_stride + y_q_ptr += g_id * y_stride + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < N + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: torch.dtype = fp8_type_, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. 
+ + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + + Args: + x: The input tenosr with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. + dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. + """ + assert ( + x.shape[-1] % group_size == 0 + ), "the last dimension of `x` cannot be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + finfo = torch.finfo(dtype) + fp8_max = finfo.max + + if is_hip_: + fp8_max = 224.0 + + fp8_min = -fp8_max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + _per_token_group_quant_fp8[(M,)]( + x, + x_q, + x_s, + group_size, + N, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + +@triton.jit +def _w8a8_block_fp8_matmul( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and store the result in output + tensor `C`. 
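+
+    For each K-block, the partial dot product is rescaled by the per-token-group
+    scale of `A` and the per-block scale of `B` for that block before being
+    accumulated in float32; the accumulator is cast to `C`'s dtype on store.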
+ """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +@functools.lru_cache +def get_w8a8_block_fp8_configs( + N: int, K: int, block_n: int, block_k: int +) -> Optional[Dict[int, Any]]: + """ + Return optimized configurations for the w8a8 block fp8 kernel. + + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. + """ + + # First look up if an optimized configuration is available in the configs + # directory + device_name = get_device_name().replace(" ", "_") + json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n}, {block_k}].json" + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info( + "Using configuration from %s for W8A8 Block FP8 kernel.", + config_file_path, + ) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + logger.warning( + ( + "Using default W8A8 Block FP8 kernel config. Performance might be sub-optimal! 
" + "Config file not found at %s" + ), + config_file_path, + ) + return None + + +def w8a8_block_fp8_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise quantization. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + + Returns: + torch.Tensor: The result of matmul. + """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N,) + C = A.new_empty(C_shape, dtype=output_dtype) + + configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1]) + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Default config + # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_size[0], + "BLOCK_SIZE_K": block_size[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + } + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + _w8a8_block_fp8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) + + return C diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 3ba381a373f..d6ff12ee163 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -1,7 +1,16 @@ -from typing import Optional, Tuple +from typing import List, Optional, Tuple import torch +from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter +from sglang.srt.layers.quantization.fp8_kernel import ( + per_token_group_quant_fp8, + w8a8_block_fp8_matmul, +) +from sglang.srt.utils import is_hip + +is_hip_ = is_hip() + def normalize_e4m3fn_to_e4m3fnuz( weight: torch.Tensor, @@ -25,3 +34,89 @@ def normalize_e4m3fn_to_e4m3fnuz( if input_scale is not None: input_scale = input_scale * 2.0 return weight, weight_scale, input_scale + + +def apply_w8a8_block_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + block_size: List[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], 
weight.shape[0]] + + q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1]) + output = w8a8_block_fp8_matmul( + q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype + ) + + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + +def input_to_float8( + x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn +) -> Tuple[torch.Tensor, torch.Tensor]: + """This function quantizes input values to float8 values with tensor-wise quantization.""" + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + fp8_max = finfo.max + if is_hip_: + fp8_max = 224.0 + scale = fp8_max / amax + x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + +def block_quant_to_tensor_quant( + x_q_block: torch.Tensor, + x_s: torch.Tensor, + block_size: List[int], +) -> Tuple[torch.Tensor, torch.Tensor]: + """This function converts block-wise quantization to tensor-wise quantization. + The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale + and the block size. + The outputs are tensor-wise quantization tensor and tensor-wise quantization scale. + Note only float8 is supported for now. + """ + block_n, block_k = block_size[0], block_size[1] + n, k = x_q_block.shape + n_tiles = (n + block_n - 1) // block_n + k_tiles = (k + block_k - 1) // block_k + assert n_tiles == x_s.shape[0] + assert k_tiles == x_s.shape[1] + + x_dq_block = x_q_block.to(torch.float32) + + x_dq_block_tiles = [ + [ + x_dq_block[ + j * block_n : min((j + 1) * block_n, n), + i * block_k : min((i + 1) * block_k, k), + ] + for i in range(k_tiles) + ] + for j in range(n_tiles) + ] + + for i in range(k_tiles): + for j in range(n_tiles): + x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i] + + x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype) + return x_q_tensor, scale + + +class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + block-wise quantization. Uses both column and row parallelism. 
+ """ + + pass diff --git a/python/sglang/srt/layers/quantization/int8_kernel.py b/python/sglang/srt/layers/quantization/int8_kernel.py new file mode 100644 index 00000000000..91b56f9e0e9 --- /dev/null +++ b/python/sglang/srt/layers/quantization/int8_kernel.py @@ -0,0 +1,54 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _per_token_quant_int8( + x_ptr, + xq_ptr, + scale_ptr, + stride_x, + stride_xq, + N, + BLOCK: tl.constexpr, +): + # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282 + row_id = tl.program_id(0) + + cols = tl.arange(0, BLOCK) + mask = cols < N + + x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10) + scale_x = absmax / 127 + x_q = x * (127 / absmax) + x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8) + + tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask) + tl.store(scale_ptr + row_id, scale_x) + + +def per_token_quant_int8(x): + M = x.numel() // x.shape[-1] + N = x.shape[-1] + x_q = torch.empty_like(x, device=x.device, dtype=torch.int8) + scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32) + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + + assert x.is_contiguous() + _per_token_quant_int8[(M,)]( + x, + x_q, + scales, + stride_x=x.stride(-2), + stride_xq=x_q.stride(-2), + N=N, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + + return x_q, scales diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py new file mode 100644 index 00000000000..3e5f996ed10 --- /dev/null +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -0,0 +1,173 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py + +import logging +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear, + cutlass_fp8_supported, + requantize_with_max_scale, +) + +from sglang.srt.layers.linear import LinearBase, LinearMethodBase +from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) + +# Initialize logger for the module +logger = logging.getLogger(__name__) + +# Supported activation schemes for the current configuration +ACTIVATION_SCHEMES = ["static"] + + +class ModelOptFp8Config(QuantizationConfig): + """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks.""" + + def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None: + """ + Args: + is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format. + """ + self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized + if is_checkpoint_fp8_serialized: + logger.warning( + "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change." 
+ ) + + @classmethod + def get_name(cls) -> str: + return "modelopt" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 89 # Minimum hardware capability (e.g., Hopper GPUs). + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["hf_quant_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": + quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo") + + if "FP8" not in quant_method: + raise ValueError( + "ModelOpt only supports static FP8 quantization in SGLang. " + "Check the `hf_quant_config.json` file for your model's configuration." + ) + + return cls(is_checkpoint_fp8_serialized=True) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + return ModelOptFp8LinearMethod(self) if isinstance(layer, LinearBase) else None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class ModelOptFp8LinearMethod(LinearMethodBase): + """Linear method for ModelOpt static FP8 quantization. + + Supports loading FP8 checkpoints with static weight and activation scales. + Future support may include dynamic scales. + + **Limitations**: + 1. Only supports per-tensor quantization due to `torch._scaled_mm` limitations. + 2. Only supports the `float8_e4m3fn` data type. + + Args: + quant_config (ModelOptFp8Config): The ModelOpt quantization configuration. + """ + + def __init__(self, quant_config: ModelOptFp8Config): + super().__init__() + self.quant_config = quant_config + self.cutlass_fp8_supported = cutlass_fp8_supported() + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + """Creates and registers weights, weight scales, and input scales for FP8 quantization.""" + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + weight_dtype = ( + torch.float8_e4m3fn + if self.quant_config.is_checkpoint_fp8_serialized + else params_dtype + ) + + # Set layer attributes + layer.logical_widths = output_partition_sizes + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + # Register weight + layer.register_parameter( + "weight", + ModelWeightParameter( + data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=weight_dtype, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ), + ) + + if self.quant_config.is_checkpoint_fp8_serialized: + # Register weight and input scales + for scale_name in ["weight_scale", "input_scale"]: + layer.register_parameter( + scale_name, + PerTensorScaleParameter( + data=torch.full( + (len(output_partition_sizes),), + torch.finfo(torch.float32).min, + dtype=torch.float32, + ), + weight_loader=weight_loader, + ), + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Requantizes weights after loading using the maximum scale.""" + max_w_scale, quantized_weight = requantize_with_max_scale( + layer.weight, layer.weight_scale, layer.logical_widths + ) + layer.weight = Parameter(quantized_weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) + + def apply( + 
self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Applies FP8 linear transformation.""" + return apply_fp8_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported, + ) diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py new file mode 100644 index 00000000000..87ba4cfc559 --- /dev/null +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -0,0 +1,117 @@ +from typing import Any, Dict, List, Optional + +import torch + +from sglang.srt.utils import is_cuda_available + +is_cuda = is_cuda_available() +if is_cuda: + from sgl_kernel import int8_scaled_mm + +from torch.nn.parameter import Parameter + +from sglang.srt.layers.linear import LinearMethodBase +from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 + + +class W8A8Int8Config(QuantizationConfig): + """Config class for W8A8 Int8 Quantization. + + - Weight: static, per-channel, symmetric + - Activation: dynamic, per-token, symmetric + """ + + def __init__(self): + pass + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def get_name(self) -> str: + return "w8a8_int8" + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config": + return cls() + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + from sglang.srt.layers.linear import LinearBase + + if isinstance(layer, LinearBase): + return W8A8Int8LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class W8A8Int8LinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: W8A8Int8Config): + self.quantization_config = quantization_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = Parameter(layer.weight.t(), requires_grad=False) + layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs + ): + + weight_loader = extra_weight_attrs.get("weight_loader") + self.logical_widths = output_partition_sizes + + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8 + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ): + x_q, x_scale = per_token_quant_int8(x) + + return int8_scaled_mm( + x_q, layer.weight, x_scale, 
layer.weight_scale, out_dtype=x.dtype, bias=bias + ) diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 1df29ec68a9..0d46e7bba9a 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -47,8 +47,17 @@ def __init__( self.logit_cap = logit_cap self.sliding_window_size = sliding_window_size or -1 self.is_cross_attention = is_cross_attention + self.k_scale = None + self.v_scale = None - def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True): + def forward( + self, + q, + k, + v, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + ): if k is not None: # For cross-layer sharing, kv can be None assert v is not None diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 80158573bd6..7093bb90d81 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1,54 +1,933 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""MRotaryEmbedding""" +# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py + +"""Rotary Positional Embeddings.""" +import math from typing import Any, Dict, List, Optional, Tuple, Union import torch +import torch.nn as nn +from vllm import _custom_ops as ops +from vllm.model_executor.custom_op import CustomOp + +from sglang.srt.layers.custom_op_util import register_custom_op +from sglang.srt.utils import is_cuda_available + +_is_cuda_available = is_cuda_available() +if _is_cuda_available: + from sgl_kernel import apply_rope_with_cos_sin_cache_inplace + + +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. 
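+
+    Returns:
+        `x` with the rotary embedding applied to its last dimension; the shape
+        is unchanged.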
+ """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +@register_custom_op("sglang_rotary_embedding") +class RotaryEmbedding(CustomOp): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability + if not _is_cuda_available: + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if _is_cuda_available: + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=query, + key=key, + head_size=self.head_size, + cos_sin_cache=self.cos_sin_cache, + is_neox=self.is_neox_style, + ) + else: + self.cos_sin_cache = 
self.cos_sin_cache.to(query.device, dtype=query.dtype) + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) + return query, key + + def forward_xpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm._ipex_ops import ipex_ops as ops + + self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype) + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) + return query, key + + def forward_hpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, + ) + + positions = positions.flatten() + if offsets is not None: + positions = positions + offsets + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions).view(num_tokens, 1, -1) + cos, sin = cos_sin.chunk(2, dim=-1) + # HPU RoPE kernel requires hidden dimension for cos and sin to be equal + # to query hidden dimension, so the original tensors need to be + # expanded + # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE + # and expansion of cos/sin tensors via concatenation + # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE + # and expansion of cos/sin tensors via repeat_interleave + rope_mode: RotaryPosEmbeddingMode + if self.is_neox_style: + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + cos = torch.cat((cos, cos), dim=-1) + sin = torch.cat((sin, sin), dim=-1) + else: + rope_mode = RotaryPosEmbeddingMode.PAIRWISE + sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1]) + cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1]) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. + + It supports multiple scaling factors. Since multiple LoRA adapters may have + different scaling factors, we need multiple cos/sin caches. In this way, + instead of running rotary embedding kernel per lora, we can run multiple + lora in a batched way. + + In addition to that, we also keep the cos/sin cache for the scaling factor + of 1 (default) at all times. + + Exemplary for two scaling factors x=1, y and z with embeddings + [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and + [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and + [[z11, z12, ... 
z1p], ..., [zn1, zn2, ..., znp]], + + we construct the cos/sin cache as follows: + [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p], + ... + [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]] + + We then use offsets to index into the cos/sin cache for + the respective scaling factors. + + The offset to cache can be accessed via `scaling_factor_to_offset` API. + + Credits to the Reddit user /u/kaiokendev + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factors: Union[List[float], float], + dtype: torch.dtype, + ) -> None: + if isinstance(scaling_factors, float): + scaling_factors = [scaling_factors] + self.scaling_factors: List[float] = scaling_factors # noqa + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + # Lazy initialized. + self._scaling_factor_to_offset: Dict[float, int] + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + cache_list: List[torch.Tensor] = [] + # offsets to the next cache in a tensor. + # Each offset corresponds to the same index in scaling_factors. + offsets: List[int] = [] + for scaling_factor in self.scaling_factors: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * scaling_factor + t = torch.arange(max_len, dtype=torch.float) + t = t / scaling_factor + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + if not cache_list: + offset = 0 + else: + last_offset = offsets[-1] + next_max_len = cache_list[-1].shape[0] + offset = last_offset + next_max_len + offsets.append(offset) + cache_list.append(cache) + self._scaling_factor_to_offset = { + float(scaling_factor): offsets[i] + for i, scaling_factor in enumerate(self.scaling_factors) + } + assert len(self.scaling_factors) == len(offsets) + return torch.cat(cache_list, dim=0) + + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self._scaling_factor_to_offset + + +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. + + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + ) -> None: + self.scaling_factor = scaling_factor + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. 
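+        # Dynamic NTK scaling rescales the rotary base so the longest wavelengths
+        # stretch with the extended context:
+        #   base' = base * (s * max_len / L - (s - 1)) ** (d / (d - 2))
+        # where s is the scaling factor, L the original max_position_embeddings
+        # and d the rotary dimension.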
+ max_len = self.max_position_embeddings * self.scaling_factor + base = self.base * ( + (self.scaling_factor * max_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.rotary_dim / (self.rotary_dim - 2)) + inv_freq = self._compute_inv_freq(base) + t = torch.arange(max_len, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> Tuple[int, int]: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask( + low: float, high: float, dim: int, dtype: torch.dtype +) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def _yarn_get_mscale(scale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + +class YaRNScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. 
github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 + - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, dtype=torch.float32 + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + +class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): + """Phi3 family of models scaled rotary embedding. + + Based on the original RotaryEmbedding implementation. + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + original_max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + short_factor: List[float], + long_factor: List[float], + short_mscale: Optional[float] = None, + long_mscale: Optional[float] = None, + ): + super().__init__() + + if rotary_dim != head_size: + raise ValueError( + f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \ + rotary_dim != head_size ({rotary_dim}!={head_size})." + ) + if is_neox_style is False: + raise ValueError( + "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style." 
+ ) + + self.head_size = head_size + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.base = base + self.short_factor = short_factor + self.long_factor = long_factor + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt( + 1 + math.log(scale) / math.log(self.original_max_position_embeddings) + ) + if short_mscale is None: + short_mscale = scaling_factor + if long_mscale is None: + long_mscale = scaling_factor + + self.short_mscale = short_mscale + self.long_mscale = long_mscale + + short_cache = self._compute_cos_sin_cache( + original_max_position_embeddings, short_factor, short_mscale + ) + short_cache = short_cache.to(dtype) + self.register_buffer("short_cos_sin_cache", short_cache, persistent=False) + + long_cache = self._compute_cos_sin_cache( + max_position_embeddings, long_factor, long_mscale + ) + long_cache = long_cache.to(dtype) + self.register_buffer("long_cos_sin_cache", long_cache, persistent=False) + + long_short_cache = torch.cat( + [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0 + ) + self.register_buffer( + "long_short_cos_sin_cache", long_short_cache, persistent=False + ) + + def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor: + rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) + inv_freq = 1.0 / ( + rescale_factors + * ( + self.base + ** ( + torch.arange(0, self.head_size, 2, dtype=torch.float) + / self.head_size + ) + ) + ) + return inv_freq + + def _compute_cos_sin_cache( + self, + max_position_embeddings: int, + rescale_factors: List[float], + mscale: float, + ) -> torch.Tensor: + inv_freq = self._compute_inv_freq(rescale_factors) + t = torch.arange(max_position_embeddings, dtype=torch.float) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * mscale + sin = freqs.sin() * mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + + k = self.original_max_position_embeddings + long_prompt_offset = ( + torch.any(positions > k).float() * torch.full_like(positions, k) + ).long() + idx = ( + torch.add(positions, long_prompt_offset) + if long_prompt_offset is not None + else positions + ) + self.long_short_cos_sin_cache: torch.Tensor = self.long_short_cos_sin_cache.to( + idx.device + ) + idx = torch.add(idx, offsets) if offsets is not None else idx + cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) + + cos, sin = cos_sin.chunk(2, dim=-1) + cos = cos.repeat(1, 2).unsqueeze(-2) + sin = sin.repeat(1, 2).unsqueeze(-2) + + query = query * cos + _rotate_neox(query) * sin + key = key * cos + _rotate_neox(key) * sin + + return query.flatten(-2), key.flatten(-2) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. 
github.com/jquesnelle/yarn + """ -class MRotaryEmbedding: + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + device: Optional[str] = "cuda", + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) + / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) + * attn_factor + ) + self.device = device + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device) + / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 + - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, + device=self.device, + dtype=torch.float32, + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + print("Cache shape", cache.shape) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query_rot = query[..., : self.rotary_dim] + key_rot = key[..., : self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim :] + key_pass = key[..., self.rotary_dim :] + + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device) + cos_sin = self.cos_sin_cache[ + torch.add(positions, offsets) if offsets is not None else positions + ] + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. 
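+            # The cache stores cos/sin for rotary_dim // 2 frequencies, so for the
+            # Neox style they are repeated to span the full rotary_dim before the
+            # rotate-half formula below (the GPT-J branch interleaves them instead).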
+ cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key + + +class Llama3RotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + scaling_factor: float, + low_freq_factor: float, + high_freq_factor: float, + orig_max_position: int, + ) -> None: + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.orig_max_position = orig_max_position + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + low_freq_wavelen = self.orig_max_position / self.low_freq_factor + high_freq_wavelen = self.orig_max_position / self.high_freq_factor + + wave_len = 2 * math.pi / inv_freqs + if self.low_freq_factor != self.high_freq_factor: + smooth = (self.orig_max_position / wave_len - self.low_freq_factor) / ( + self.high_freq_factor - self.low_freq_factor + ) + else: + smooth = 0 + new_freqs = torch.where( + wave_len < high_freq_wavelen, + inv_freqs, + torch.where( + wave_len > low_freq_wavelen, + inv_freqs / self.scaling_factor, + (1 - smooth) * inv_freqs / self.scaling_factor + smooth * inv_freqs, + ), + ) + return new_freqs + + +class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + mrope_section: Optional[List[int]] = None, + ) -> None: + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + self.mrope_section = mrope_section + if self.mrope_section: + assert sum(self.mrope_section) == rotary_dim // 2 + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward(). 
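+
+        For multimodal inputs (positions of shape [3, num_tokens]) the cos/sin
+        values are gathered per temporal/height/width section according to
+        `mrope_section` and concatenated along the last dimension before the
+        rotation is applied.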
+ + Args: + positions: + [num_tokens,] (text only) or + [3, num_tokens] (T/H/W positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + assert positions.ndim == 1 or positions.ndim == 2 + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + cos = torch.cat( + [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], + dim=-1, + ) + sin = torch.cat( + [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], + dim=-1, + ) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + @staticmethod def get_input_positions( - input_tokens: torch.Tensor, + input_tokens: List[int], image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + image_token_id: int, + video_token_id: int, vision_start_token_id: int, + vision_end_token_id: int, spatial_merge_size: int, context_len: int = 0, + seq_len: Optional[int] = None, ) -> Tuple[List[List[int]], int]: """Get mrope input positions and delta value.""" if isinstance(image_grid_thw, torch.Tensor): image_grid_thw = image_grid_thw.tolist() + if isinstance(video_grid_thw, torch.Tensor): + video_grid_thw = video_grid_thw.tolist() + input_tokens_tensor = torch.tensor(input_tokens) vision_start_indices = torch.argwhere( - input_tokens == vision_start_token_id + input_tokens_tensor == vision_start_token_id ).squeeze(1) - image_indices = vision_start_indices + 1 - image_nums = image_indices.shape[0] + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() llm_pos_ids_list: list = [] st = 0 - input_tokens_len = input_tokens.shape[0] - for image_index in range(image_nums): - ed = image_indices[image_index].item() - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video llm_grid_t, llm_grid_h, llm_grid_w = ( t, h // spatial_merge_size, @@ -84,16 
+963,17 @@ def get_input_positions( ) st = ed + llm_grid_t * llm_grid_h * llm_grid_w - if st < input_tokens_len: + if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = input_tokens_len - st + text_len = len(input_tokens) - st llm_pos_ids_list.append( torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx ) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:] - mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item() + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + return llm_positions.tolist(), mrope_position_delta @staticmethod @@ -110,3 +990,292 @@ def get_next_input_positions( ) for _ in range(3) ] + + +_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} + + +def get_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + key = ( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling_args, + dtype, + ) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + if rope_scaling is None: + rotary_emb = RotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, dtype + ) + else: + if "rope_type" in rope_scaling: + scaling_type = rope_scaling["rope_type"] + elif "type" in rope_scaling: + scaling_type = rope_scaling["type"] + else: + raise ValueError("Unknown RoPE scaling type") + + if scaling_type == "llama3": + scaling_factor = rope_scaling["factor"] + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + scaling_factor, + low_freq_factor, + high_freq_factor, + original_max_position, + ) + elif scaling_type == "default": + if "mrope_section" in rope_scaling: + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + ) + else: + rotary_emb = RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + ) + elif scaling_type == "linear": + scaling_factor = rope_scaling["factor"] + rotary_emb = LinearScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) + elif scaling_type == "dynamic": + scaling_factor = rope_scaling["factor"] + rotary_emb = DynamicNTKScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) + elif scaling_type == "yarn": + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in 
rope_scaling.items() + if k + in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow") + } + rotary_emb = YaRNScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) + elif scaling_type == "deepseek_yarn": + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + # assert max_position == original_max_position * scaling_factor + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k + in ( + "extrapolation_factor", + "attn_factor", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ) + } + rotary_emb = DeepseekScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) + elif scaling_type == "longrope": + short_factor = rope_scaling["short_factor"] + long_factor = rope_scaling["long_factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("short_mscale", "long_mscale") + } + rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( + head_size, + rotary_dim, + max_position, + original_max_position, + base, + is_neox_style, + dtype, + short_factor, + long_factor, + **extra_kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + _ROPE_DICT[key] = rotary_emb + return rotary_emb + + +def get_rope_cpu( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, + device: Optional[str] = None, +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + key = ( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling_args, + dtype, + ) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + assert rope_scaling is not None + scaling_type = rope_scaling["rope_type"] + assert ( + scaling_type == "deepseek_yarn" + ), "Only deepseek_yarn is supported for CPU for now" + + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k + in ( + "extrapolation_factor", + "attn_factor", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ) + } + extra_kwargs["device"] = device + rotary_emb = DeepseekScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) + + _ROPE_DICT[key] = rotary_emb + return rotary_emb + + +def get_rope_wrapper( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, + device: Optional[str] = None, +): + if device != "cpu": + return get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + 
rope_scaling, + dtype, + partial_rotary_factor, + ) + + return get_rope_cpu( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling, + dtype, + partial_rotary_factor, + device, + ) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index b0dfda3e882..73ef13c35f2 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -1,16 +1,19 @@ import logging -from typing import Union +from typing import List import torch +import torch.distributed as dist from torch import nn +from sglang.srt.distributed import get_tensor_model_parallel_group +from sglang.srt.layers.dp_attention import get_attention_tp_group from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo -from sglang.srt.utils import crash_on_warnings, is_flashinfer_available +from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available -if is_flashinfer_available(): - from flashinfer.sampling import ( +if is_cuda_available(): + from sgl_kernel import ( min_p_sampling_from_probs, top_k_renorm_prob, top_k_top_p_sampling_from_probs, @@ -20,21 +23,30 @@ logger = logging.getLogger(__name__) +SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP") + class Sampler(nn.Module): def __init__(self): super().__init__() self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"] + self.tp_sync_group = get_tensor_model_parallel_group().device_group + + if global_server_args_dict["enable_dp_attention"]: + self.tp_sync_group = get_attention_tp_group().device_group def forward( self, - logits: Union[torch.Tensor, LogitsProcessorOutput], + logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo, + return_logprob: bool, + top_logprobs_nums: List[int], ): - if isinstance(logits, LogitsProcessorOutput): - logits = logits.next_token_logits + logits = logits_output.next_token_logits - logits = logits.contiguous() + # Apply the custom logit processors if registered in the sampling info. + if sampling_info.has_custom_logit_processor: + self._apply_custom_logit_processor(logits, sampling_info) if self.use_nan_detectioin and torch.any(torch.isnan(logits)): logger.warning("Detected errors during sampling! NaN in the logits.") @@ -47,14 +59,25 @@ def forward( if sampling_info.is_all_greedy: # Use torch.argmax if all requests use greedy sampling batch_next_token_ids = torch.argmax(logits, -1) + if return_logprob: + logprobs = torch.nn.functional.log_softmax(logits, dim=-1) else: # Post process logits logits.div_(sampling_info.temperatures) probs = torch.softmax(logits, dim=-1) - logits = None del logits if global_server_args_dict["sampling_backend"] == "flashinfer": + if return_logprob: + # NOTE: the top_p_renorm_prob from flashinfer has numerical problems, + # https://github.com/flashinfer-ai/flashinfer/issues/708 + # so we use the torch implementation. 
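
Editor's note: the torch fallback referenced in the comment above renormalizes the probabilities after top-p filtering and only then takes the log. A self-contained sketch of that idea (same shape of logic as top_p_normalize_probs_torch defined further down in this file, not the flashinfer kernel):

import torch

def top_p_renormalize(probs: torch.Tensor, top_ps: torch.Tensor) -> torch.Tensor:
    """Zero out tokens outside each row's top-p nucleus, then renormalize."""
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Drop tokens whose preceding cumulative mass already exceeds top_p.
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # Scatter back to the original token order.
    return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)

probs = torch.softmax(torch.randn(2, 8), dim=-1)
top_ps = torch.tensor([0.9, 0.5])
# Clamp so tokens zeroed by the filter map to a finite log-probability.
logprobs = torch.log(top_p_renormalize(probs, top_ps)).clamp(
    min=torch.finfo(probs.dtype).min
)
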
+ + # clamp to avoid -inf + logprobs = torch.log( + top_p_normalize_probs_torch(probs, sampling_info.top_ps) + ).clamp(min=torch.finfo(probs.dtype).min) + max_top_k_round, batch_size = 32, probs.shape[0] uniform_samples = torch.rand( (max_top_k_round, batch_size), device=probs.device @@ -77,6 +100,7 @@ def forward( if self.use_nan_detectioin and not torch.all(success): logger.warning("Detected errors during sampling!") batch_next_token_ids = torch.zeros_like(batch_next_token_ids) + elif global_server_args_dict["sampling_backend"] == "pytorch": # A slower fallback implementation with torch native operations. batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch( @@ -84,34 +108,129 @@ def forward( sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps, + sampling_info.need_min_p_sampling, ) + if return_logprob: + # clamp to avoid -inf + logprobs = torch.log( + top_p_normalize_probs_torch(probs, sampling_info.top_ps) + ).clamp(min=torch.finfo(probs.dtype).min) else: raise ValueError( f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" ) + # Attach logprobs to logits_output (in-place modification) + if return_logprob: + if any(x > 0 for x in top_logprobs_nums): + ( + logits_output.next_token_top_logprobs_val, + logits_output.next_token_top_logprobs_idx, + ) = get_top_logprobs(logprobs, top_logprobs_nums) + + logits_output.next_token_logprobs = logprobs[ + torch.arange(len(batch_next_token_ids), device=sampling_info.device), + batch_next_token_ids, + ] + + if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars: + # For performance reasons, SGLang does not sync the final token IDs across TP ranks by default. + # This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators: + # the last all-reduce, the last lm_head matmul, and all sampling kernels. + # These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic. + # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized. + # When using xgrammar, this becomes more likely so we also do the sync when grammar is used. + + torch.distributed.all_reduce( + batch_next_token_ids, + op=dist.ReduceOp.MIN, + group=self.tp_sync_group, + ) + return batch_next_token_ids.to(torch.int32) + def _apply_custom_logit_processor( + self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo + ): + """Apply custom logit processors to the logits. + This function will modify the logits in-place.""" + + assert logits.shape[0] == len(sampling_batch_info), ( + f"The batch size of logits ({logits.shape[0]}) does not match the batch size of " + f"sampling_batch_info ({len(sampling_batch_info)})" + ) + + for _, ( + processor, + batch_mask, + ) in sampling_batch_info.custom_logit_processor.items(): + # Get the batch indices that need to be processed + batch_indices = batch_mask.nonzero(as_tuple=True)[0] + + assert batch_mask.shape[0] == len(sampling_batch_info), ( + f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of " + f"sampling_batch_info ({len(sampling_batch_info)})" + ) + + # Apply the processor to the logits + logits[batch_mask] = processor( + logits[batch_mask], + [sampling_batch_info.custom_params[i] for i in batch_indices], + ) + + logger.debug( + f"Custom logit processor {processor.__class__.__name__} is applied." 
+ ) + def top_k_top_p_min_p_sampling_from_probs_torch( probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor, min_ps: torch.Tensor, + need_min_p_sampling: bool, ): """A top-k, top-p and min-p sampling implementation with native pytorch operations.""" probs_sort, probs_idx = probs.sort(dim=-1, descending=True) probs_sum = torch.cumsum(probs_sort, dim=-1) - min_p_thresholds = probs_sort[:, 0] * min_ps - probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 probs_sort[ torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1) >= top_ks.view(-1, 1) ] = 0.0 - probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0 - probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0]) + probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 + + if need_min_p_sampling: + min_p_thresholds = probs_sort[:, 0] * min_ps + probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0 + sampled_index = torch.multinomial(probs_sort, num_samples=1) # int32 range is enough to represent the token ids probs_idx = probs_idx.to(torch.int32) batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1) return batch_next_token_ids + + +def top_p_normalize_probs_torch( + probs: torch.Tensor, + top_ps: torch.Tensor, +): + # See also top_k_top_p_min_p_sampling_from_probs_torch + probs_sort, probs_idx = probs.sort(dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0 + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort) + + +def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]): + max_k = max(top_logprobs_nums) + ret = logprobs.topk(max_k, dim=1) + values = ret.values.tolist() + indices = ret.indices.tolist() + + output_top_logprobs_val = [] + output_top_logprobs_idx = [] + for i, k in enumerate(top_logprobs_nums): + output_top_logprobs_val.append(values[i][:k]) + output_top_logprobs_idx.append(indices[i][:k]) + return output_top_logprobs_val, output_top_logprobs_idx diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index 910309da973..e08abd5ae1d 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -2,11 +2,44 @@ Common utilities for torchao. 
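
Editor's note: a toy, standalone walk-through of the get_top_logprobs helper above (no sglang imports): take a single topk with the largest requested k, then slice each row down to that request's own k.

import torch

logprobs = torch.log_softmax(torch.randn(3, 10), dim=-1)
top_logprobs_nums = [1, 3, 2]          # per-request k values (toy numbers)

max_k = max(top_logprobs_nums)
values, indices = logprobs.topk(max_k, dim=1)
per_request_vals = [values[i, :k].tolist() for i, k in enumerate(top_logprobs_nums)]
per_request_idx = [indices[i, :k].tolist() for i, k in enumerate(top_logprobs_nums)]
assert [len(v) for v in per_request_vals] == top_logprobs_nums
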
""" +import logging +import os +import pwd +from typing import Callable, Optional + import torch +logger = logging.getLogger(__name__) + + +def get_gemlite_cache_path() -> str: + return f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json" + + +def save_gemlite_cache(print_error: bool = False) -> bool: + try: + from gemlite.core import GemLiteLinearTriton + + GemLiteLinearTriton.cache_config(get_gemlite_cache_path()) + except Exception: + if print_error: + logger.error("Failed to save the GemLite cache.") + return False + return True + + +def proj_filter( + module: torch.nn.Module, + fqn: str, +): + """Filter function for quantizing projection layers.""" + return "proj" in fqn + def apply_torchao_config_to_model( - model: torch.nn.Module, torchao_config: str, filter_fn=None + model: torch.nn.Module, + torchao_config: str, + filter_fn: Optional[Callable] = proj_filter, ): """Quantize a modelwith torchao quantization specified by torchao_config @@ -27,11 +60,6 @@ def apply_torchao_config_to_model( ) from torchao.quantization.observer import PerRow, PerTensor - if filter_fn is None: - - def filter_fn(module, fqn): - return "proj" in fqn - if torchao_config == "" or torchao_config is None: return model elif "int8wo" in torchao_config: @@ -47,6 +75,29 @@ def filter_fn(module, fqn): 256, ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}" quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn) + elif "gemlite" in torchao_config: + # gemlite--- or + # gemlite-- (packing_bitwidth defaults to 32) + from gemlite.core import GemLiteLinearTriton + from torchao.quantization import gemlite_uintx_weight_only + + _quant_args = torchao_config.split("-") + bit_width = int(_quant_args[-2]) + group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1]) + + try: + packing_bitwidth = int(_quant_args[-3]) + except (ValueError, IndexError): + # if only 2 inputs found or conversion fails, use default value + packing_bitwidth = 32 + + quantize_( + model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth) + ) + + # try to load gemlite kernel config + GemLiteLinearTriton.load_config(get_gemlite_cache_path()) + elif "fp8wo" in torchao_config: # this requires newer hardware # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index effea1c6c95..ed9d67ef970 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -6,14 +6,14 @@ import torch import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter -from vllm.distributed import ( + +from sglang.srt.distributed import ( divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.parameter import BasevLLMParameter - +from sglang.srt.layers.parameter import BasevLLMParameter from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -220,6 +220,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, prefix: str = "", enable_tp: bool = True, + use_presharded_weights: bool = False, ): super().__init__() self.quant_config = quant_config @@ -236,6 +237,12 @@ def __init__( self.padding_size = padding_size self.org_vocab_size = org_num_embeddings or num_embeddings num_added_embeddings = num_embeddings - 
self.org_vocab_size + self.use_presharded_weights = use_presharded_weights + if use_presharded_weights: + assert ( + num_added_embeddings == 0 + ), "Lora is not supported with presharded weights." + self.org_vocab_size_padded = pad_vocab_size( self.org_vocab_size, self.padding_size ) @@ -447,10 +454,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = start_idx // packed_factor shard_size = shard_size // packed_factor else: - assert loaded_weight.shape[output_dim] == self.org_vocab_size + assert loaded_weight.shape[output_dim] == ( + self.org_vocab_size + // (self.tp_size if self.use_presharded_weights else 1) + ) # Copy the data. - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + if not self.use_presharded_weights: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) param[: loaded_weight.shape[0]].data.copy_(loaded_weight) param[loaded_weight.shape[0] :].data.fill_(0) @@ -514,6 +525,7 @@ def __init__( padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_presharded_weights: bool = False, ): super().__init__( num_embeddings, @@ -523,6 +535,7 @@ def __init__( padding_size, quant_config, prefix, + use_presharded_weights=use_presharded_weights, ) self.quant_config = quant_config if bias: diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 839d10222e2..c8cbe36602b 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -19,18 +19,11 @@ # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py -import json -import os import re -from typing import Any, Dict, List, Optional, Tuple -import safetensors.torch import torch from torch import nn -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -38,7 +31,6 @@ QKVParallelLinear, RowParallelLinear, ) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.loader import DefaultModelLoader diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py new file mode 100644 index 00000000000..4560a270870 --- /dev/null +++ b/python/sglang/srt/managers/cache_controller.py @@ -0,0 +1,307 @@ +from __future__ import annotations + +""" +Copyright 2023-2025 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
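
Editor's note: a sketch (toy shapes, no TP runtime) of the presharded-weight path added to the vocab-parallel weight_loader above: with presharded checkpoints each rank's file already contains only its own vocab shard, so the expected shape shrinks by tp_size and the narrow() step is skipped.

import torch

org_vocab_size, hidden, tp_size, tp_rank = 8, 4, 2, 1
shard_size = org_vocab_size // tp_size
start_idx = tp_rank * shard_size

full_weight = torch.randn(org_vocab_size, hidden)    # regular checkpoint
presharded_weight = torch.randn(shard_size, hidden)  # presharded checkpoint

def load(loaded_weight: torch.Tensor, use_presharded_weights: bool) -> torch.Tensor:
    expected = org_vocab_size // (tp_size if use_presharded_weights else 1)
    assert loaded_weight.shape[0] == expected
    if not use_presharded_weights:
        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
    return loaded_weight

assert load(full_weight, False).shape == load(presharded_weight, True).shape
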
+""" + +import logging +import threading +from queue import PriorityQueue, Queue +from typing import Optional + +import torch + +from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost + +logger = logging.getLogger(__name__) + + +class CacheOperation: + + counter = 0 + + def __init__( + self, + host_indices: torch.Tensor, + device_indices: torch.Tensor, + node_id: int, + priority: Optional[int] = None, + ): + self.host_indices = host_indices + self.device_indices = device_indices + self.node_ids = [node_id] + self.data = None + + self.id = CacheOperation.counter + CacheOperation.counter += 1 + # default priority is the order of creation + self.priority = priority if priority is not None else self.id + + def merge(self, other: "CacheOperation") -> None: + # multiple operations can be merged into a single operation for batch processing + self.host_indices = torch.cat([self.host_indices, other.host_indices]) + self.device_indices = torch.cat([self.device_indices, other.device_indices]) + self.priority = min(self.priority, other.priority) + self.node_ids.extend(other.node_ids) + + def __lt__(self, other: "CacheOperation"): + return self.priority < other.priority + + +class TransferBuffer: + """ + Overlapping buffer preparation and transfer operations to improve throughput. + """ + + def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None: + self.buffers = Queue(maxsize=buffer_count) + # todo: adjust the buffer size based on throughput profile of the system + self.max_buffer_size = max_buffer_size + + def full(self) -> bool: + return self.buffers.full() + + def empty(self) -> bool: + return self.buffers.empty() + + def put(self, item, block=True) -> None: + self.buffers.put(item, block=block) + + def get(self, block=True) -> Optional[CacheOperation]: + try: + return self.buffers.get(block=block) + except Exception as e: + logger.error(e) + + +class HiCacheController: + + def __init__( + self, + mem_pool_device: MHATokenToKVPool, + mem_pool_host: MLATokenToKVPoolHost, + write_policy: str = "write_through_selective", + ): + + self.mem_pool_device = mem_pool_device + self.mem_pool_host = mem_pool_host + self.write_policy = write_policy + + if write_policy not in [ + "write_through", + "write_through_selective", + "write_back", + ]: + raise ValueError(f"Invalid write policy: {write_policy}") + + self.write_queue = PriorityQueue() + self.load_queue = PriorityQueue() + + self.ack_write_queue = Queue() + self.ack_load_queue = Queue() + + self.write_buffer = TransferBuffer() + self.load_buffer = TransferBuffer() + + self.write_stream = torch.cuda.Stream() + self.load_stream = torch.cuda.Stream() + + self.write_thread = threading.Thread( + target=self.write_thread_func_buffer, daemon=True + ) + self.load_thread = threading.Thread( + target=self.load_thread_func_buffer, daemon=True + ) + self.write_thread.start() + self.load_thread.start() + + def write( + self, + device_indices: torch.Tensor, + priority: Optional[int] = None, + node_id: int = 0, + ) -> Optional[torch.Tensor]: + """ + Back up KV caches from device memory to host memory. 
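
Editor's note: a toy stand-in (not the real class) illustrating how CacheOperation above behaves in the priority queues: operations order by priority, which defaults to creation order, and merge() keeps the smallest priority so a merged batch is scheduled no later than its oldest member.

import torch
from queue import PriorityQueue

class ToyOp:
    counter = 0

    def __init__(self, host_indices, device_indices, node_id):
        self.host_indices = host_indices
        self.device_indices = device_indices
        self.node_ids = [node_id]
        self.priority = ToyOp.counter
        ToyOp.counter += 1

    def merge(self, other: "ToyOp") -> None:
        self.host_indices = torch.cat([self.host_indices, other.host_indices])
        self.device_indices = torch.cat([self.device_indices, other.device_indices])
        self.priority = min(self.priority, other.priority)
        self.node_ids.extend(other.node_ids)

    def __lt__(self, other: "ToyOp") -> bool:
        return self.priority < other.priority

a = ToyOp(torch.tensor([0, 1]), torch.tensor([4, 5]), node_id=1)
b = ToyOp(torch.tensor([2]), torch.tensor([6]), node_id=2)

q: PriorityQueue = PriorityQueue()
q.put(b)
q.put(a)
assert q.get() is a            # the older operation drains first

a.merge(b)                     # batch both transfers into one operation
assert a.node_ids == [1, 2] and a.priority == 0
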
+ """ + host_indices = self.mem_pool_host.alloc(len(device_indices)) + if host_indices is None: + return None + self.write_queue.put( + CacheOperation(host_indices, device_indices, node_id, priority) + ) + self.mem_pool_host.protect_write(host_indices) + return host_indices + + def load( + self, + host_indices: torch.Tensor, + priority: Optional[int] = None, + node_id: int = 0, + ) -> Optional[torch.Tensor]: + """ + Load KV caches from host memory to device memory. + """ + device_indices = self.mem_pool_device.alloc(len(host_indices)) + if device_indices is None: + return None + self.load_queue.put( + CacheOperation(host_indices, device_indices, node_id, priority) + ) + self.mem_pool_host.protect_load(host_indices) + return device_indices + + def write_thread_func_direct(self): + """ + Directly write through KV caches to host memory without buffering. + """ + with torch.cuda.stream(self.write_stream): + while True: + try: + operation = self.write_queue.get(block=True) + operation.data = self.mem_pool_device.get_flat_data( + operation.device_indices + ) + self.mem_pool_host.transfer(operation.host_indices, operation.data) + self.mem_pool_host.complete_io(operation.host_indices) + for node_id in operation.node_ids: + self.ack_write_queue.put(node_id) + except Exception as e: + logger.error(e) + + def load_thread_func_direct(self): + """ + Directly load KV caches from host memory to device memory without buffering. + """ + with torch.cuda.stream(self.load_stream): + while True: + try: + operation = self.load_queue.get(block=True) + operation.data = self.mem_pool_host.get_flat_data( + operation.host_indices + ) + self.mem_pool_device.transfer( + operation.device_indices, operation.data + ) + self.mem_pool_host.complete_io(operation.host_indices) + for node_id in operation.node_ids: + self.ack_load_queue.put(node_id) + except Exception as e: + logger.error(e) + + def write_aux_func(self, no_wait=False): + """ + Auxiliary function to prepare the buffer for write operations. + """ + buffer = None + while True: + try: + operation = self.write_queue.get(block=True) + if buffer is None: + buffer = operation + else: + buffer.merge(operation) + if ( + no_wait + or len(buffer.host_indices) >= self.write_buffer.max_buffer_size + or self.write_queue.empty() + or self.write_buffer.empty() + ): + assert ( + buffer.device_indices.is_cuda + ), "Device indices should be on GPU" + buffer.data = self.mem_pool_device.get_flat_data( + buffer.device_indices + ).contiguous() + self.write_buffer.put(buffer, block=True) + buffer = None + except Exception as e: + logger.error(e) + + def load_aux_func(self): + """ + Auxiliary function to prepare the buffer for load operations. 
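
Editor's note: the *_aux_func / *_thread_func_buffer pairs above form a two-stage pipeline: one thread coalesces queued operations and stages their data into a bounded buffer, while a second thread performs the actual transfer. A generic, self-contained sketch of that overlap pattern (plain queues, no KV pools or CUDA streams):

import threading
import time
from queue import Queue

work_queue: Queue = Queue()        # stand-in for write_queue / load_queue
staging: Queue = Queue(maxsize=3)  # stand-in for the bounded TransferBuffer

def aux(batch_size: int = 4) -> None:
    batch = []
    while True:
        batch.append(work_queue.get())
        if len(batch) >= batch_size or work_queue.empty():
            staging.put(batch)     # blocks when the transfer thread falls behind
            batch = []

def transfer() -> None:
    while True:
        items = staging.get()
        print(f"transferred {len(items)} items")

threading.Thread(target=aux, daemon=True).start()
threading.Thread(target=transfer, daemon=True).start()
for i in range(10):
    work_queue.put(i)
time.sleep(0.2)                    # give the daemon threads time to drain
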
+ """ + buffer = None + while True: + try: + operation = self.load_queue.get(block=True) + if buffer is None: + buffer = operation + else: + buffer.merge(operation) + if ( + len(buffer.host_indices) >= self.load_buffer.max_buffer_size + or self.load_queue.empty() + or self.load_buffer.empty() + ): + buffer.data = ( + self.mem_pool_host.get_flat_data(buffer.host_indices) + .contiguous() + .pin_memory() + ) + self.load_buffer.put(buffer, block=True) + buffer = None + except Exception as e: + logger.error(e) + + def write_thread_func_buffer(self): + aux_thread = threading.Thread(target=self.write_aux_func, daemon=True) + aux_thread.start() + with torch.cuda.stream(self.write_stream): + while True: + operation = self.write_buffer.get() + if operation is None: + continue + self.mem_pool_host.transfer(operation.host_indices, operation.data) + self.mem_pool_host.complete_io(operation.host_indices) + for node_id in operation.node_ids: + self.ack_write_queue.put(node_id) + + def load_thread_func_buffer(self): + aux_thread = threading.Thread(target=self.load_aux_func, daemon=True) + aux_thread.start() + with torch.cuda.stream(self.load_stream): + while True: + operation = self.load_buffer.get() + if operation is None: + continue + self.mem_pool_device.transfer(operation.device_indices, operation.data) + self.mem_pool_host.complete_io(operation.host_indices) + for node_id in operation.node_ids: + self.ack_load_queue.put(node_id) + + def evict_device( + self, device_indices: torch.Tensor, host_indices: torch.Tensor + ) -> int: + if self.mem_pool_host.is_synced(host_indices): + self.mem_pool_device.free(device_indices) + self.mem_pool_host.update_backup(host_indices) + return len(device_indices) + else: + raise ValueError( + f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}" + ) + + def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int: + if not backup_only: + raise ValueError("Other eviction policies are not supported yet.") + + if self.mem_pool_host.is_backup(host_indices): + self.mem_pool_host.free(host_indices) + return len(host_indices) + else: + raise ValueError( + f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}" + ) diff --git a/python/sglang/srt/managers/configure_logging.py b/python/sglang/srt/managers/configure_logging.py new file mode 100644 index 00000000000..187af4d9c08 --- /dev/null +++ b/python/sglang/srt/managers/configure_logging.py @@ -0,0 +1,46 @@ +""" +Copyright 2023-2025 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +Configure the logging settings of a server. 
+ +Usage: +python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000 +""" + +import argparse + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--url", type=str, default="http://localhost:30000") + parser.add_argument("--log-requests", action="store_true") + parser.add_argument( + "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump" + ) + parser.add_argument("--dump-requests-threshold", type=int, default=1000) + args = parser.parse_args() + + response = requests.post( + args.url + "/configure_logging", + json={ + "log_requests": args.log_requests, + "log_requests_level": 1, # Log full requests + "dump_requests_folder": args.dump_requests_folder, + "dump_requests_threshold": args.dump_requests_threshold, + }, + ) + assert response.status_code == 200 diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 8edb79417e2..3b959b1ba76 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -20,8 +20,10 @@ from enum import Enum, auto import psutil +import setproctitle import zmq +from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.managers.io_struct import ( TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -54,6 +56,7 @@ class DataParallelController: def __init__(self, server_args, port_args) -> None: # Parse args + self.max_total_num_tokens = None self.server_args = server_args self.port_args = port_args self.load_balance_method = LoadBalanceMethod.from_str( @@ -62,9 +65,10 @@ def __init__(self, server_args, port_args) -> None: # Init inter-process communication self.context = zmq.Context(1 + server_args.dp_size) - self.recv_from_tokenizer = get_zmq_socket( - self.context, zmq.PULL, port_args.scheduler_input_ipc_name - ) + if server_args.node_rank == 0: + self.recv_from_tokenizer = get_zmq_socket( + self.context, zmq.PULL, port_args.scheduler_input_ipc_name, False + ) # Dispatch method self.round_robin_counter = 0 @@ -74,33 +78,50 @@ def __init__(self, server_args, port_args) -> None: } self.dispatching = dispatch_lookup[self.load_balance_method] - # Start data parallel workers - base_gpu_id = 0 + # Launch data parallel workers + self.scheduler_procs = [] self.workers = [None] * server_args.dp_size + if not server_args.enable_dp_attention: + dp_port_args = self.launch_dp_schedulers(server_args, port_args) + else: + dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args) + + # Only node rank 0 runs the real data parallel controller that dispatches the requests. + if server_args.node_rank == 0: + for dp_rank in range(server_args.dp_size): + self.workers[dp_rank] = get_zmq_socket( + self.context, + zmq.PUSH, + dp_port_args[dp_rank].scheduler_input_ipc_name, + True, + ) + + self.max_req_input_len = None + + def launch_dp_schedulers(self, server_args, port_args): + base_gpu_id = 0 + threads = [] sockets = [] + dp_port_args = [] for dp_rank in range(server_args.dp_size): tmp_port_args = PortArgs.init_new(server_args) tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name + dp_port_args.append(tmp_port_args) - if server_args.enable_dp_attention: - # Data parallelism resues the tensor parallelism group, - # so all dp ranks should use the same nccl port. 
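
Editor's note: a standalone sketch of the PULL-in / round-robin PUSH fan-out that the data parallel controller above sets up (one PUSH socket per DP rank, a counter for round-robin dispatch). The ipc endpoint names here are made up for the example.

import zmq

context = zmq.Context()
recv_from_tokenizer = context.socket(zmq.PULL)
recv_from_tokenizer.bind("ipc:///tmp/example_scheduler_input")

workers = []
for dp_rank in range(2):
    sock = context.socket(zmq.PUSH)
    sock.connect(f"ipc:///tmp/example_worker_{dp_rank}")
    workers.append(sock)

round_robin_counter = 0

def dispatch(req) -> None:
    """Send the request to the next worker in round-robin order."""
    global round_robin_counter
    workers[round_robin_counter].send_pyobj(req)
    round_robin_counter = (round_robin_counter + 1) % len(workers)
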
- tmp_port_args.nccl_port = port_args.nccl_port - else: - # This port is checked free in PortArgs.init_new. - # We hold it first so that the next dp worker gets a different port - sockets.append(bind_port(tmp_port_args.nccl_port)) + # This port is checked free in PortArgs.init_new. + # We hold it first so that the next dp worker gets a different port + sockets.append(bind_port(tmp_port_args.nccl_port)) # Create a thread for each worker thread = threading.Thread( - target=self.launch_worker_func, + target=self.launch_tensor_parallel_group, args=(server_args, tmp_port_args, base_gpu_id, dp_rank), ) threads.append(thread) - base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size + base_gpu_id += server_args.tp_size # Free all sockets before starting the threads to launch TP workers for sock in sockets: @@ -112,26 +133,14 @@ def __init__(self, server_args, port_args) -> None: for thread in threads: thread.join() - def launch_worker_func( - self, - server_args: ServerArgs, - port_args: PortArgs, - base_gpu_id: int, - dp_rank: int, - ): - logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.") + return dp_port_args - launch_func_ = ( - self.launch_tensor_parallel_process - if server_args.enable_dp_attention - else self.launch_tensor_parallel_group - ) - self.workers[dp_rank] = launch_func_( - server_args, - port_args, - base_gpu_id, - dp_rank, - ) + def launch_dp_attention_schedulers(self, server_args, port_args): + self.launch_tensor_parallel_group(server_args, port_args, 0, None) + dp_port_args = [] + for dp_rank in range(server_args.dp_size): + dp_port_args.append(PortArgs.init_new(server_args, dp_rank)) + return dp_port_args def launch_tensor_parallel_group( self, @@ -140,8 +149,10 @@ def launch_tensor_parallel_group( base_gpu_id: int, dp_rank: int, ): + if not server_args.enable_dp_attention: + logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.") + # Launch tensor parallel scheduler processes - scheduler_procs = [] scheduler_pipe_readers = [] tp_size_per_node = server_args.tp_size // server_args.nnodes tp_rank_range = range( @@ -149,52 +160,39 @@ def launch_tensor_parallel_group( tp_size_per_node * (server_args.node_rank + 1), ) for tp_rank in tp_rank_range: + rank_port_args = port_args + + if server_args.enable_dp_attention: + # dp attention has different sharding logic + _, _, dp_rank = compute_dp_attention_world_info( + server_args.enable_dp_attention, + tp_rank, + server_args.tp_size, + server_args.dp_size, + ) + # compute zmq ports for this dp rank + rank_port_args = PortArgs.init_new(server_args, dp_rank) + # Data parallelism resues the tensor parallelism group, + # so all dp ranks should use the same nccl port. 
+ rank_port_args.nccl_port = port_args.nccl_port + reader, writer = mp.Pipe(duplex=False) gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node proc = mp.Process( target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer), + args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer), ) proc.start() - scheduler_procs.append(proc) + self.scheduler_procs.append(proc) scheduler_pipe_readers.append(reader) - send_to = get_zmq_socket( - self.context, zmq.PUSH, port_args.scheduler_input_ipc_name - ) - - # Wait for model to finish loading and get max token nums + # Wait for model to finish loading scheduler_info = [] for i in range(len(scheduler_pipe_readers)): scheduler_info.append(scheduler_pipe_readers[i].recv()) self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"] - - return send_to - - def launch_tensor_parallel_process( - self, - server_args: ServerArgs, - port_args: PortArgs, - base_gpu_id: int, - dp_rank: int, - ): - reader, writer = mp.Pipe(duplex=False) - gpu_id = base_gpu_id - tp_rank = dp_rank - proc = mp.Process( - target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer), - ) - proc.start() - send_to = get_zmq_socket( - self.context, zmq.PUSH, port_args.scheduler_input_ipc_name - ) - - scheduler_info = reader.recv() - self.max_total_num_tokens = scheduler_info["max_total_num_tokens"] - - return send_to + self.max_req_input_len = scheduler_info[0]["max_req_input_len"] def round_robin_scheduler(self, req): self.workers[self.round_robin_counter].send_pyobj(req) @@ -220,8 +218,8 @@ def event_loop(self): ): self.dispatching(recv_req) else: - # Send other control messages to all workers - for worker in self.workers: + # Send other control messages to first worker of tp group + for worker in self.workers[:: self.server_args.tp_size]: worker.send_pyobj(recv_req) @@ -230,15 +228,26 @@ def run_data_parallel_controller_process( port_args: PortArgs, pipe_writer, ): + setproctitle.setproctitle("sglang::data_parallel_controller") configure_logger(server_args) parent_process = psutil.Process().parent() try: controller = DataParallelController(server_args, port_args) pipe_writer.send( - {"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens} + { + "status": "ready", + "max_total_num_tokens": controller.max_total_num_tokens, + "max_req_input_len": controller.max_req_input_len, + } ) - controller.event_loop() + if server_args.node_rank == 0: + controller.event_loop() + for proc in controller.scheduler_procs: + proc.join() + logger.error( + f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" + ) except Exception: traceback = get_exception_traceback() logger.error(f"DataParallelController hit an exception: {traceback}") diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index e74ba5026c1..a8ded73bccc 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -15,11 +15,13 @@ import dataclasses import logging +import os import signal from collections import OrderedDict -from typing import List, Union +from typing import Dict, List, Union import psutil +import setproctitle import zmq from sglang.srt.hf_transformers_utils import get_tokenizer @@ -28,13 +30,18 @@ BatchStrOut, BatchTokenIDOut, ) -from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN from sglang.srt.server_args 
import PortArgs, ServerArgs from sglang.srt.utils import configure_logger, get_zmq_socket from sglang.utils import find_printable_text, get_exception_traceback logger = logging.getLogger(__name__) +# Maximum number of request states that detokenizer can hold. When exceeded, +# oldest request states will be evicted. Default: 65536 (1<<16). +# For more details, see: https://github.com/sgl-project/sglang/issues/2812 +# Use power of 2 values for better memory allocation. +DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 << 16)) + @dataclasses.dataclass class DecodeStatus: @@ -58,10 +65,10 @@ def __init__( # Init inter-process communication context = zmq.Context(2) self.recv_from_scheduler = get_zmq_socket( - context, zmq.PULL, port_args.detokenizer_ipc_name + context, zmq.PULL, port_args.detokenizer_ipc_name, True ) self.send_to_tokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_ipc_name + context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) if server_args.skip_tokenizer_init: @@ -71,21 +78,30 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) - self.decode_status = LimitedCapacityDict() + self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES) + + def trim_matched_stop( + self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool + ): + if no_stop_trim or not finished_reason: + return output - def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim): - if no_stop_trim: + matched = finished_reason.get("matched", None) + if not matched: return output - # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit - if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str): - pos = output.find(finished_reason.matched) + # TODO(lmzheng): handle the case where multiple stop strs are hit + + # Trim stop str. + if isinstance(matched, str) and isinstance(output, str): + pos = output.find(matched) return output[:pos] if pos != -1 else output - if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance( - output, list - ): + + # Trim stop token. + if isinstance(matched, int) and isinstance(output, list): assert len(output) > 0 return output[:-1] return output @@ -124,9 +140,9 @@ def event_loop(self): s.decode_ids = recv_obj.decode_ids[i] read_ids.append( - self.trim_eos( + self.trim_matched_stop( s.decode_ids[s.surr_offset :], - recv_obj.finished_reason[i], + recv_obj.finished_reasons[i], recv_obj.no_stop_trim[i], ) ) @@ -147,9 +163,19 @@ def event_loop(self): # Incremental decoding output_strs = [] for i in range(bs): - s = self.decode_status[recv_obj.rids[i]] + try: + s = self.decode_status[recv_obj.rids[i]] + except KeyError: + raise RuntimeError( + f"Decode status not found for request {recv_obj.rids[i]}. " + "It may be due to the request being evicted from the decode status due to memory pressure. " + "Please increase the maximum number of requests by setting " + "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. " + f"The current value is {DETOKENIZER_MAX_STATES}. 
" + "For more details, see: https://github.com/sgl-project/sglang/issues/2812" + ) new_text = read_texts[i][len(surr_texts[i]) :] - if recv_obj.finished_reason[i] is None: + if recv_obj.finished_reasons[i] is None: # Streaming chunk: update the decode status if len(new_text) > 0 and not new_text.endswith("�"): s.decoded_text = s.decoded_text + new_text @@ -160,9 +186,9 @@ def event_loop(self): new_text = find_printable_text(new_text) output_strs.append( - self.trim_eos( + self.trim_matched_stop( s.decoded_text + new_text, - recv_obj.finished_reason[i], + recv_obj.finished_reasons[i], recv_obj.no_stop_trim[i], ) ) @@ -170,15 +196,26 @@ def event_loop(self): self.send_to_tokenizer.send_pyobj( BatchStrOut( rids=recv_obj.rids, + finished_reasons=recv_obj.finished_reasons, output_strs=output_strs, - meta_info=recv_obj.meta_info, - finished_reason=recv_obj.finished_reason, + prompt_tokens=recv_obj.prompt_tokens, + completion_tokens=recv_obj.completion_tokens, + cached_tokens=recv_obj.cached_tokens, + spec_verify_ct=recv_obj.spec_verify_ct, + input_token_logprobs_val=recv_obj.input_token_logprobs_val, + input_token_logprobs_idx=recv_obj.input_token_logprobs_idx, + output_token_logprobs_val=recv_obj.output_token_logprobs_val, + output_token_logprobs_idx=recv_obj.output_token_logprobs_idx, + input_top_logprobs_val=recv_obj.input_top_logprobs_val, + input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, + output_top_logprobs_val=recv_obj.output_top_logprobs_val, + output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, ) ) class LimitedCapacityDict(OrderedDict): - def __init__(self, capacity=1 << 15, *args, **kwargs): + def __init__(self, capacity: int, *args, **kwargs): super().__init__(*args, **kwargs) self.capacity = capacity @@ -194,6 +231,7 @@ def run_detokenizer_process( server_args: ServerArgs, port_args: PortArgs, ): + setproctitle.setproctitle("sglang::detokenizer") configure_logger(server_args) parent_process = psutil.Process().parent() diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index 7120fa48d52..f43ecb18c16 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -9,6 +9,8 @@ import numpy as np import transformers +from decord import VideoReader, cpu +from PIL import Image from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.mm_utils import expand2square, process_anyres_image @@ -36,6 +38,7 @@ class BaseImageProcessor(ABC): def __init__(self, hf_config, server_args, _processor): self.hf_config = hf_config self._processor = _processor + self.server_args = server_args self.executor = concurrent.futures.ProcessPoolExecutor( initializer=init_global_processor, @@ -126,7 +129,12 @@ async def _process_single_image( ) async def process_images_async( - self, image_data: List[Union[str, bytes]], input_text, request_obj + self, + image_data: List[Union[str, bytes]], + input_text, + request_obj, + *args, + **kwargs, ): if not image_data: return None @@ -229,6 +237,186 @@ async def process_images_async( return image_inputs +class MiniCPMVImageProcessor(BaseImageProcessor): + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) + self.IMAGE_TOKEN = "(./)" + + @staticmethod + def _process_images_task(images, input_text): + result = global_processor.__call__( + text=input_text, images=images, return_tensors="pt" + ) + return { + "input_ids": result["input_ids"], + "pixel_values": result["pixel_values"], 
+ "tgt_sizes": result["tgt_sizes"], + } + + async def _process_images(self, images, input_text): + if self.executor is not None: + loop = asyncio.get_event_loop() + image_inputs = await loop.run_in_executor( + self.executor, + MiniCPMVImageProcessor._process_images_task, + images, + input_text, + ) + else: + image_inputs = self._processor( + images=images, text=input_text, return_tensors="pt" + ) + + return image_inputs + + async def process_images_async( + self, + image_data: List[Union[str, bytes]], + input_ids, + request_obj, + max_req_input_len, + ): + if not image_data: + return None + + if not isinstance(image_data, list): + image_data = [image_data] + + image_hashes, image_sizes = [], [] + all_frames = [] + + # roughly calculate the max number of frames under the max_req_input_len limit + def calculate_max_num_frames() -> int: + # Model-specific + NUM_TOKEN_PER_FRAME = 330 + + ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME + return min(ret, 100) + + MAX_NUM_FRAMES = calculate_max_num_frames() + + # print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}") + + def get_estimated_frames_list(): + """ + estimate the total frame count from all visual input + """ + # Before processing inputs + estimated_frames_list = [] + for image in image_data: + if isinstance(image, str) and image.startswith("video:"): + path = image[len("video:") :] + # Estimate frames for the video + vr = VideoReader(path, ctx=cpu(0)) + num_frames = len(vr) + else: + # For images, each contributes one frame + num_frames = 1 + estimated_frames_list.append(num_frames) + + return estimated_frames_list + + estimated_frames_list = get_estimated_frames_list() + total_frame_count = sum(estimated_frames_list) + scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count) + + def encode_video(video_path, frame_count_limit=None): + if not os.path.exists(video_path): + logger.error(f"Video {video_path} does not exist") + return [] + + if frame_count_limit == 0: + return [] + + def uniform_sample(l, n): + gap = len(l) / n + idxs = [int(i * gap + gap / 2) for i in range(n)] + return [l[i] for i in idxs] + + vr = VideoReader(video_path, ctx=cpu(0)) + sample_fps = round(vr.get_avg_fps() / 1) # FPS + frame_idx = [i for i in range(0, len(vr), sample_fps)] + if frame_count_limit is not None and len(frame_idx) > frame_count_limit: + frame_idx = uniform_sample(frame_idx, frame_count_limit) + frames = vr.get_batch(frame_idx).asnumpy() + frames = [Image.fromarray(v.astype("uint8")) for v in frames] + return frames + + if isinstance(input_ids, list): + assert len(input_ids) and isinstance(input_ids[0], int) + input_text = self._processor.tokenizer.decode(input_ids) + else: + input_text = input_ids + # MiniCPMV requires each frame of video as a single image token + text_parts = input_text.split(self.IMAGE_TOKEN) + new_text_parts = [] + + # Process each input with allocated frames + for image_index, (image, estimated_frames) in enumerate( + zip(image_data, estimated_frames_list) + ): + if len(all_frames) >= MAX_NUM_FRAMES: + frames_to_process = 0 + else: + frames_to_process = max(1, int(estimated_frames * scaling_factor)) + + if frames_to_process == 0: + frames = [] + else: + try: + if isinstance(image, str) and image.startswith("video:"): + path = image[len("video:") :] + frames = encode_video(path, frame_count_limit=frames_to_process) + else: + raw_image, _size = load_image(image) + frames = [raw_image] + if len(frames) == 0: + continue + except FileNotFoundError as e: + print(e) + return None + image_sizes += frames[0].size * 
len(frames) + image_hashes += [hash(image)] * len(frames) + all_frames += frames + + assert frames_to_process == len(frames) + + new_text_parts.append(text_parts[image_index]) + + if frames_to_process != 0: + new_text_parts.append(self.IMAGE_TOKEN * len(frames)) + + new_text_parts.append(text_parts[-1]) + + input_text = "".join(new_text_parts) + + if len(all_frames) == 0: + return None + res = await self._process_images(images=all_frames, input_text=input_text) + pixel_values = res["pixel_values"] + tgt_sizes = res["tgt_sizes"] + input_ids = res["input_ids"] + + # Collect special token ids + tokenizer = self._processor.tokenizer + im_start_id = [tokenizer.im_start_id] + im_end_id = [tokenizer.im_end_id] + if tokenizer.slice_start_id: + slice_start_id = [tokenizer.slice_start_id] + slice_end_id = [tokenizer.slice_end_id] + return { + "input_ids": input_ids.flatten().tolist(), + "pixel_values": pixel_values, + "tgt_sizes": tgt_sizes, + "image_hashes": image_hashes, + "modalities": request_obj.modalities or ["image"], + "im_start_id": im_start_id, + "im_end_id": im_end_id, + "slice_start_id": slice_start_id, + "slice_end_id": slice_end_id, + } + + class Qwen2VLImageProcessor(BaseImageProcessor): def __init__(self, hf_config, server_args, _image_processor): self.hf_config = hf_config @@ -289,7 +477,12 @@ async def _process_single_image(self, image_data: Union[bytes, str]): return self._process_single_image_task(image_data) async def process_images_async( - self, image_data: List[Union[str, bytes]], input_text, request_obj + self, + image_data: List[Union[str, bytes]], + input_text, + request_obj, + *args, + **kwargs, ): if not image_data: return None @@ -350,6 +543,8 @@ def get_image_processor( return MllamaImageProcessor(hf_config, server_args, processor) elif "Qwen2VLForConditionalGeneration" in hf_config.architectures: return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor) + elif "MiniCPMV" in hf_config.architectures: + return MiniCPMVImageProcessor(hf_config, server_args, processor) else: return LlavaImageProcessor(hf_config, server_args, processor.image_processor) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 27bf5a4bdb1..f7419d04f33 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -17,14 +17,22 @@ """ import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Union from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.sampling.sampling_params import SamplingParams +@dataclass +class SessionParams: + id: Optional[str] = None + rid: Optional[str] = None + offset: Optional[int] = None + replace: Optional[bool] = None + + @dataclass class GenerateReqInput: # The input prompt. It can be a single prompt or a batch of prompts. @@ -51,15 +59,20 @@ class GenerateReqInput: return_text_in_logprobs: bool = False # Whether to stream output. stream: bool = False + # Whether to log metrics for this request (e.g. 
health_generate calls do not log metrics) + log_metrics: bool = True + # The modalities of the image data [image, multi-images, video] modalities: Optional[List[str]] = None # LoRA related lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None - # Session id info for continual prompting - session: Optional[ - Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]] - ] = None + # Session info for continual prompting + session_params: Optional[Union[List[Dict], Dict]] = None + # Custom logit processor for advanced sampling control. Must be a serialized instance + # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py + # Use the processor's `to_str()` method to generate the serialized string. + custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None def normalize_batch_and_arguments(self): if ( @@ -174,6 +187,13 @@ def normalize_batch_and_arguments(self): else: assert self.parallel_sample_num == 1 + if self.custom_logit_processor is None: + self.custom_logit_processor = [None] * num + elif not isinstance(self.custom_logit_processor, list): + self.custom_logit_processor = [self.custom_logit_processor] * num + else: + assert self.parallel_sample_num == 1 + def regenerate_rid(self): self.rid = uuid.uuid4().hex return self.rid @@ -190,8 +210,14 @@ def __getitem__(self, i): top_logprobs_num=self.top_logprobs_num[i], return_text_in_logprobs=self.return_text_in_logprobs, stream=self.stream, + log_metrics=self.log_metrics, modalities=self.modalities[i] if self.modalities else None, lora_path=self.lora_path[i] if self.lora_path is not None else None, + custom_logit_processor=( + self.custom_logit_processor[i] + if self.custom_logit_processor is not None + else None + ), ) @@ -221,9 +247,13 @@ class TokenizedGenerateReqInput: # The input embeds input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None - # Session id info for continual prompting - session_id: Optional[str] = None - session_rid: Optional[str] = None + # Session info for continual prompting + session_params: Optional[SessionParams] = None + + # Custom logit processor for advanced sampling control. Must be a serialized instance + # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py + # Use the processor's `to_str()` method to generate the serialized string. + custom_logit_processor: Optional[str] = None @dataclass @@ -238,6 +268,8 @@ class EmbeddingReqInput: sampling_params: Union[List[Dict], Dict] = None # Dummy input embeds for compatibility input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None + # Whether to log metrics for this request (e.g. 
health_generate calls do not log metrics) + log_metrics: bool = True def normalize_batch_and_arguments(self): if (self.text is None and self.input_ids is None) or ( @@ -308,42 +340,74 @@ class TokenizedEmbeddingReqInput: class BatchTokenIDOut: # The request id rids: List[str] + # The finish reason + finished_reasons: List[BaseFinishReason] + # For incremental decoding # The version id to sync decode status with in detokenizer_manager vids: List[int] decoded_texts: List[str] decode_ids: List[int] read_offsets: List[int] - # Only used when `--skip-tokenizer-init` + # Only used when `--skip-tokenizer-init` is on output_ids: Optional[List[int]] + # Detokenization configs skip_special_tokens: List[bool] spaces_between_special_tokens: List[bool] - meta_info: List[Dict] - finished_reason: List[BaseFinishReason] no_stop_trim: List[bool] + # Token counts + prompt_tokens: List[int] + completion_tokens: List[int] + cached_tokens: List[int] + spec_verify_ct: List[int] + + # Logprobs + input_token_logprobs_val: List[float] + input_token_logprobs_idx: List[int] + output_token_logprobs_val: List[float] + output_token_logprobs_idx: List[int] + input_top_logprobs_val: List[List] + input_top_logprobs_idx: List[List] + output_top_logprobs_val: List[List] + output_top_logprobs_idx: List[List] + @dataclass class BatchStrOut: # The request id rids: List[str] + # The finish reason + finished_reasons: List[dict] # The output decoded strings output_strs: List[str] - # The meta info - meta_info: List[Dict] - # The finish reason - finished_reason: List[BaseFinishReason] + + # Token counts + prompt_tokens: List[int] + completion_tokens: List[int] + cached_tokens: List[int] + spec_verify_ct: List[int] + + # Logprobs + input_token_logprobs_val: List[float] + input_token_logprobs_idx: List[int] + output_token_logprobs_val: List[float] + output_token_logprobs_idx: List[int] + input_top_logprobs_val: List[List] + input_top_logprobs_idx: List[List] + output_top_logprobs_val: List[List] + output_top_logprobs_idx: List[List] @dataclass class BatchEmbeddingOut: # The request id rids: List[str] + # The finish reason + finished_reasons: List[BaseFinishReason] # The output embedding embeddings: List[List[float]] - # The meta info - meta_info: List[Dict] - # The finish reason - finished_reason: List[BaseFinishReason] + # Token counts + prompt_tokens: List[int] @dataclass @@ -378,6 +442,17 @@ class UpdateWeightsFromDistributedReqOutput: message: str +@dataclass +class UpdateWeightsFromTensorReqInput: + serialized_named_tensors: bytes # indeed Dict[str, torch.Tensor] + + +@dataclass +class UpdateWeightsFromTensorReqOutput: + success: bool + message: str + + @dataclass class InitWeightsUpdateGroupReqInput: # The master address @@ -411,6 +486,26 @@ class GetWeightsByNameReqOutput: parameter: list +@dataclass +class ReleaseMemoryOccupationReqInput: + pass + + +@dataclass +class ReleaseMemoryOccupationReqOutput: + pass + + +@dataclass +class ResumeMemoryOccupationReqInput: + pass + + +@dataclass +class ResumeMemoryOccupationReqOutput: + pass + + @dataclass class AbortReq: # The request id @@ -422,9 +517,18 @@ class ProfileReq(Enum): STOP_PROFILE = 2 +@dataclass +class ConfigureLoggingReq: + log_requests: Optional[bool] = None + log_requests_level: Optional[int] = None + dump_requests_folder: Optional[str] = None + dump_requests_threshold: Optional[int] = None + + @dataclass class OpenSessionReqInput: capacity_of_str_len: int + session_id: Optional[str] = None @dataclass @@ -434,4 +538,29 @@ class CloseSessionReqInput: @dataclass class 
 OpenSessionReqOutput:
-    session_id: str
+    session_id: Optional[str]
+    success: bool
+
+
+@dataclass
+class Function:
+    description: Optional[str] = None
+    name: Optional[str] = None
+    parameters: Optional[object] = None
+
+
+@dataclass
+class Tool:
+    function: Function
+    type: Optional[str] = "function"
+
+
+@dataclass
+class FunctionCallReqInput:
+    text: str  # The text to parse.
+    tools: List[Tool] = field(
+        default_factory=list
+    )  # A list of available function tools (name, parameters, etc.).
+    tool_call_parser: Optional[str] = (
+        None  # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
+    )
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 5855d4248ff..f22d3d5fe74 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,7 +31,7 @@
 import dataclasses
 import logging
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -42,11 +44,14 @@
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
 
+if TYPE_CHECKING:
+    from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
+
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
@@ -59,9 +64,9 @@
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
+    "device": ServerArgs.device,
 }
 
-
 logger = logging.getLogger(__name__)
@@ -110,14 +115,18 @@ def to_json(self):
 class FINISH_ABORT(BaseFinishReason):
-    def __init__(self, message="Unknown error"):
+    def __init__(self, message="Unknown error", status_code=None, err_type=None):
         super().__init__(is_error=True)
         self.message = message
+        self.status_code = status_code
+        self.err_type = err_type
 
     def to_json(self):
         return {
             "type": "abort",
             "message": self.message,
+            "status_code": self.status_code,
+            "err_type": self.err_type,
         }
@@ -129,6 +138,7 @@ class ImageInputs:
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
+    image_pad_len: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
@@ -141,6 +151,15 @@ class ImageInputs:
     image_grid_thws: List[Tuple[int, int, int]] = None
     mrope_position_delta: Optional[torch.Tensor] = None
 
+    # MiniCPMV related
+    # All the images in the batch should share the same special image
+    # bound token ids.
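For the function-call parsing structs added to io_struct.py above (Function, Tool, FunctionCallReqInput), here is a hedged usage sketch. The dataclasses are re-declared locally so the snippet runs on its own; the get_weather tool and the <tool_call> text are made-up examples, and the parser names are only those listed in the comment above.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Function:
    description: Optional[str] = None
    name: Optional[str] = None
    parameters: Optional[object] = None


@dataclass
class Tool:
    function: Function
    type: Optional[str] = "function"


@dataclass
class FunctionCallReqInput:
    text: str
    tools: List[Tool] = field(default_factory=list)
    tool_call_parser: Optional[str] = None  # e.g. "llama3", "qwen25", "mistral"


req = FunctionCallReqInput(
    text='<tool_call>{"name": "get_weather", "arguments": {"city": "Austin"}}</tool_call>',
    tools=[
        Tool(
            function=Function(
                name="get_weather",
                description="Look up current weather for a city.",
                parameters={
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            )
        )
    ],
    tool_call_parser=None,  # None lets the server try every available parser
)
print(req.tools[0].function.name)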
+ im_start_id: Optional[torch.Tensor] = None + im_end_id: Optional[torch.Tensor] = None + slice_start_id: Optional[torch.Tensor] = None + slice_end_id: Optional[torch.Tensor] = None + tgt_sizes: Optional[list] = None + @staticmethod def from_dict(obj: dict): ret = ImageInputs( @@ -160,6 +179,11 @@ def from_dict(obj: dict): "aspect_ratio_ids", "aspect_ratio_mask", "image_grid_thws", + "im_start_id", + "im_end_id", + "slice_start_id", + "slice_end_id", + "tgt_sizes", ] for arg in optional_args: if arg in obj: @@ -181,6 +205,7 @@ def merge(self, other): optional_args = [ "image_sizes", "image_offsets", + "image_pad_len", # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images "aspect_ratio_ids", "aspect_ratio_mask", @@ -200,10 +225,15 @@ def __init__( origin_input_text: str, origin_input_ids: Tuple[int], sampling_params: SamplingParams, + return_logprob: bool = False, + top_logprobs_num: int = 0, + stream: bool = False, origin_input_ids_unpadded: Optional[Tuple[int]] = None, lora_path: Optional[str] = None, input_embeds: Optional[List[List[float]]] = None, session_id: Optional[str] = None, + custom_logit_processor: Optional[str] = None, + eos_token_ids: Optional[Set[int]] = None, ): # Input and output info self.rid = rid @@ -214,13 +244,16 @@ def __init__( else origin_input_ids # Before image padding ) self.origin_input_ids = origin_input_ids - self.output_ids = [] # Each decode stage's output ids - self.fill_ids = None # fill_ids = origin_input_ids + output_ids + # Each decode stage's output ids + self.output_ids = [] + # fill_ids = origin_input_ids + output_ids. Updated if chunked. + self.fill_ids = None self.session_id = session_id + self.input_embeds = input_embeds + # Sampling info self.sampling_params = sampling_params - self.lora_path = lora_path - self.input_embeds = input_embeds + self.custom_logit_processor = custom_logit_processor # Memory pool info self.req_pool_idx = None @@ -228,8 +261,9 @@ def __init__( # Check finish self.tokenizer = None self.finished_reason = None - self.stream = False self.to_abort = False + self.stream = stream + self.eos_token_ids = eos_token_ids # For incremental decoding # ----- | --------- read_ids -------| @@ -241,37 +275,46 @@ def __init__( # 2: read_offset # 3: last token self.vid = 0 # version id to sync decode status with in detokenizer_manager - self.decoded_text = "" self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm self.read_offset = None - - # The number of decoded tokens for token usage report. Note that - # this does not include the jump forward tokens. - self.completion_tokens_wo_jump_forward = 0 + self.decoded_text = "" # For multimodal inputs self.image_inputs: Optional[ImageInputs] = None # Prefix info self.prefix_indices = [] + # Tokens to run prefill. input_tokens - shared_prefix_tokens. + # Updated if chunked. 
self.extend_input_len = 0 self.last_node = None + + # Chunked prefill self.is_being_chunked = 0 # For retraction self.is_retracted = False # Logprobs (arguments) - self.return_logprob = False + self.return_logprob = return_logprob self.logprob_start_len = 0 - self.top_logprobs_num = 0 - - # Logprobs (return value) - self.normalized_prompt_logprob = None - self.input_token_logprobs = None - self.input_top_logprobs = None - self.output_token_logprobs = [] - self.output_top_logprobs = [] + self.top_logprobs_num = top_logprobs_num + + # Logprobs (return values) + self.input_token_logprobs_val: Optional[List[float]] = None + self.input_token_logprobs_idx: Optional[List[int]] = None + self.input_top_logprobs_val: Optional[List[float]] = None + self.input_top_logprobs_idx: Optional[List[int]] = None + + if return_logprob: + self.output_token_logprobs_val = [] + self.output_token_logprobs_idx = [] + self.output_top_logprobs_val = [] + self.output_top_logprobs_idx = [] + else: + self.output_token_logprobs_val = self.output_token_logprobs_idx = ( + self.output_top_logprobs_val + ) = self.output_top_logprobs_idx = None # Logprobs (internal values) # The tokens is prefilled but need to be considered as decode tokens @@ -286,8 +329,14 @@ def __init__( # Constrained decoding self.grammar: Optional[BaseGrammarObject] = None - # The number of cached tokens, that were already cached in the KV cache + # The number of cached tokens that were already cached in the KV cache self.cached_tokens = 0 + self.already_computed = 0 + + # The number of verification forward passes in the speculative decoding. + # This is used to compute the average acceptance length per request. + self.spec_verify_ct = 0 + self.lora_path = lora_path def extend_image_inputs(self, image_inputs): if self.image_inputs is None: @@ -295,13 +344,14 @@ def extend_image_inputs(self, image_inputs): else: self.image_inputs.merge(image_inputs) - # whether request reached finished condition def finished(self) -> bool: + # Whether request reached finished condition return self.finished_reason is not None def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None): self.fill_ids = self.origin_input_ids + self.output_ids if tree_cache is not None: + # tree cache is None if the prefix is not computed with tree cache. 
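The request-level logprob bookkeeping above now keeps values and token ids in separate parallel *_val / *_idx lists, and only allocates them when return_logprob is requested. A minimal standalone restatement of that convention (the class below is illustrative, not sglang's Req):

from typing import List, Optional, Tuple


class LogprobTrack:
    """Illustrative only: parallel val/idx lists, allocated on demand."""

    def __init__(self, return_logprob: bool):
        self.output_token_logprobs_val: Optional[List[float]] = [] if return_logprob else None
        self.output_token_logprobs_idx: Optional[List[int]] = [] if return_logprob else None

    def append(self, logprob: float, token_id: int) -> None:
        if self.output_token_logprobs_val is None:
            return  # logprobs were not requested for this request
        self.output_token_logprobs_val.append(logprob)
        self.output_token_logprobs_idx.append(token_id)

    def pairs(self) -> List[Tuple[float, int]]:
        if self.output_token_logprobs_val is None:
            return []
        return list(zip(self.output_token_logprobs_val, self.output_token_logprobs_idx))


track = LogprobTrack(return_logprob=True)
track.append(-0.12, 1734)
track.append(-2.30, 29991)
print(track.pairs())  # [(-0.12, 1734), (-2.3, 29991)]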
self.prefix_indices, self.last_node = tree_cache.match_prefix( rid=self.rid, key=self.adjust_max_prefix_ids() ) @@ -320,9 +370,6 @@ def adjust_max_prefix_ids(self): max_prefix_len = min(max_prefix_len, input_len - 1) if self.return_logprob: - if self.normalized_prompt_logprob is None: - # Need at least two tokens to compute normalized logprob - max_prefix_len = min(max_prefix_len, input_len - 2) max_prefix_len = min(max_prefix_len, self.logprob_start_len) max_prefix_len = max(max_prefix_len, 0) @@ -379,18 +426,23 @@ def check_finished(self): last_token_id = self.output_ids[-1] - matched_eos = False - - # Check stop token ids - if self.sampling_params.stop_token_ids: - matched_eos = last_token_id in self.sampling_params.stop_token_ids - if self.tokenizer is not None: - matched_eos |= last_token_id == self.tokenizer.eos_token_id - if self.tokenizer.additional_stop_token_ids: - matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids - if matched_eos and not self.sampling_params.ignore_eos: - self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id) - return + if not self.sampling_params.ignore_eos: + matched_eos = False + + # Check stop token ids + if self.sampling_params.stop_token_ids: + matched_eos = last_token_id in self.sampling_params.stop_token_ids + if self.eos_token_ids: + matched_eos |= last_token_id in self.eos_token_ids + if self.tokenizer is not None: + matched_eos |= last_token_id == self.tokenizer.eos_token_id + if self.tokenizer.additional_stop_token_ids: + matched_eos |= ( + last_token_id in self.tokenizer.additional_stop_token_ids + ) + if matched_eos: + self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id) + return # Check stop strings if len(self.sampling_params.stop_strs) > 0: @@ -454,15 +506,31 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state): k = k + 1 else: break - self.output_token_logprobs = self.output_token_logprobs[:k] - self.output_top_logprobs = self.output_top_logprobs[:k] + self.output_token_logprobs_val = self.output_token_logprobs_val[:k] + self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k] + self.output_top_logprobs_val = self.output_top_logprobs_val[:k] + self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k] self.logprob_start_len = prompt_tokens + k self.last_update_decode_tokens = len(self.output_ids) - k return True + def reset_for_retract(self): + self.prefix_indices = [] + self.last_node = None + self.extend_input_len = 0 + self.is_retracted = True + + # For incremental logprobs + # TODO: Fix the `logprob_start_len` + self.last_update_decode_tokens = 0 + self.logprob_start_len = 10**9 + def __repr__(self): - return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, " + return ( + f"rid(n={self.rid}, " + f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}" + ) bid = 0 @@ -470,7 +538,7 @@ def __repr__(self): @dataclasses.dataclass class ScheduleBatch: - """Store all inforamtion of a batch on the scheduler.""" + """Store all information of a batch on the scheduler.""" # Request, memory pool, and cache reqs: List[Req] @@ -488,13 +556,13 @@ class ScheduleBatch: next_batch_sampling_info: SamplingBatchInfo = None # Batched arguments to model runner - input_ids: torch.Tensor = None - input_embeds: torch.Tensor = None - req_pool_indices: torch.Tensor = None - seq_lens: torch.Tensor = None + input_ids: torch.Tensor = None # shape: [b], int32 + input_embeds: torch.Tensor = None # shape: [b, hidden_size], float32 + req_pool_indices: torch.Tensor = None # 
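The rewritten check_finished above gates every EOS check behind ignore_eos and folds the model-level eos_token_ids into the match. A condensed, standalone restatement of just that token-matching logic (all inputs below are placeholders):

from typing import Optional, Set


def matched_eos_token(
    last_token_id: int,
    ignore_eos: bool,
    stop_token_ids: Optional[Set[int]] = None,
    model_eos_token_ids: Optional[Set[int]] = None,
    tokenizer_eos_token_id: Optional[int] = None,
    additional_stop_token_ids: Optional[Set[int]] = None,
) -> bool:
    # Mirrors the control flow above: nothing matches when ignore_eos is set.
    if ignore_eos:
        return False
    matched = False
    if stop_token_ids:
        matched = last_token_id in stop_token_ids
    if model_eos_token_ids:
        matched |= last_token_id in model_eos_token_ids
    if tokenizer_eos_token_id is not None:
        matched |= last_token_id == tokenizer_eos_token_id
    if additional_stop_token_ids:
        matched |= last_token_id in additional_stop_token_ids
    return matched


print(matched_eos_token(2, ignore_eos=False, tokenizer_eos_token_id=2))  # True
print(matched_eos_token(2, ignore_eos=True, tokenizer_eos_token_id=2))   # False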
shape: [b], int32 + seq_lens: torch.Tensor = None # shape: [b], int64 # The output locations of the KV cache - out_cache_loc: torch.Tensor = None - output_ids: torch.Tensor = None + out_cache_loc: torch.Tensor = None # shape: [b], int32 + output_ids: torch.Tensor = None # shape: [b], int32 # The sum of all sequence lengths seq_lens_sum: int = None @@ -526,9 +594,16 @@ class ScheduleBatch: # Has grammar has_grammar: bool = False - # device + # Device device: str = "cuda" + # Speculative decoding + spec_algorithm: SpeculativeAlgorithm = None + spec_info: Optional[SpecInfo] = None + + # Enable custom logit processor + enable_custom_logit_processor: bool = False + @classmethod def init_new( cls, @@ -538,6 +613,8 @@ def init_new( tree_cache: BasePrefixCache, model_config: ModelConfig, enable_overlap: bool, + spec_algorithm: SpeculativeAlgorithm, + enable_custom_logit_processor: bool, ): return cls( reqs=reqs, @@ -550,6 +627,8 @@ def init_new( has_stream=any(req.stream for req in reqs), has_grammar=any(req.grammar for req in reqs), device=req_to_token_pool.device, + spec_algorithm=spec_algorithm, + enable_custom_logit_processor=enable_custom_logit_processor, ) def batch_size(self): @@ -605,7 +684,7 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) or len(req.prefix_indices) >= im.num_image_tokens ) - self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to( + self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to( self.device, non_blocking=True ) @@ -639,7 +718,7 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to( self.device, non_blocking=True ) - self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to( + self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to( self.device, non_blocking=True ) @@ -677,13 +756,6 @@ def prepare_for_extend(self): pt = 0 for i, req in enumerate(reqs): - already_computed = ( - req.extend_logprob_start_len + 1 + req.cached_tokens - if req.extend_logprob_start_len > 0 - else 0 - ) - req.cached_tokens += len(req.prefix_indices) - already_computed - req.req_pool_idx = req_pool_indices[i] pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids) seq_lens.append(seq_len) @@ -699,15 +771,20 @@ def prepare_for_extend(self): # If req.input_embeds is already a list, append its content directly input_embeds.extend(req.input_embeds) # Use extend to avoid nesting - # Compute the relative logprob_start_len in an extend batch - if req.logprob_start_len >= pre_len: - extend_logprob_start_len = min( - req.logprob_start_len - pre_len, req.extend_input_len - 1 - ) - else: - extend_logprob_start_len = req.extend_input_len - 1 + if req.return_logprob: + # Compute the relative logprob_start_len in an extend batch + if req.logprob_start_len >= pre_len: + extend_logprob_start_len = min( + req.logprob_start_len - pre_len, req.extend_input_len - 1 + ) + else: + raise RuntimeError( + f"This should never happen. 
{req.logprob_start_len=}, {pre_len=}" + ) + req.extend_logprob_start_len = extend_logprob_start_len - req.extend_logprob_start_len = extend_logprob_start_len + req.cached_tokens += pre_len - req.already_computed + req.already_computed = seq_len req.is_retracted = False pre_lens.append(pre_len) @@ -715,10 +792,10 @@ def prepare_for_extend(self): self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to( self.device, non_blocking=True ) - self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to( + self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to( self.device, non_blocking=True ) - self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to( + self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to( self.device, non_blocking=True ) self.input_embeds = ( @@ -804,8 +881,8 @@ def mix_with_running(self, running_batch: "ScheduleBatch"): # TODO (lianmin): Revisit this. It should be seq_len - 1 self.extend_logprob_start_lens.extend([0] * running_bs) - def check_decode_mem(self): - bs = len(self.reqs) + def check_decode_mem(self, buf_multiplier=1): + bs = len(self.reqs) * buf_multiplier if self.token_to_kv_pool.available_size() >= bs: return True @@ -876,15 +953,7 @@ def retract_decode(self): ) residual_size = max(0, residual_size) self.tree_cache.evict(residual_size, self.token_to_kv_pool.free) - - req.prefix_indices = [] - req.last_node = None - req.extend_input_len = 0 - req.is_retracted = True - - # For incremental logprobs - req.last_update_decode_tokens = 0 - req.logprob_start_len = 10**9 + req.reset_for_retract() self.filter_batch(keep_indices=sorted_indices) @@ -959,14 +1028,21 @@ def prepare_encoder_info_decode(self): def prepare_for_idle(self): self.forward_mode = ForwardMode.IDLE self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device) - self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device) + self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device) self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device) self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device) self.seq_lens_sum = 0 self.extend_num_tokens = 0 + self.sampling_info = SamplingBatchInfo.from_schedule_batch( + self, + self.model_config.vocab_size, + enable_overlap_schedule=self.enable_overlap, + ) def prepare_for_decode(self): self.forward_mode = ForwardMode.DECODE + if self.spec_algorithm.is_eagle(): + return self.input_ids = self.output_ids self.output_ids = None @@ -1022,7 +1098,7 @@ def filter_batch( self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices] self.reqs = [self.reqs[i] for i in keep_indices] - new_indices = torch.tensor(keep_indices, dtype=torch.int32).to( + new_indices = torch.tensor(keep_indices, dtype=torch.int64).to( self.device, non_blocking=True ) self.req_pool_indices = self.req_pool_indices[new_indices] @@ -1040,6 +1116,8 @@ def filter_batch( self.has_grammar = any(req.grammar for req in self.reqs) self.sampling_info.filter_batch(keep_indices, new_indices) + if self.spec_info: + self.spec_info.filter_batch(new_indices) def merge_batch(self, other: "ScheduleBatch"): # Penalizer orchestrator must be merged before Batch.reqs is merged. 
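The new cached-token accounting above (cached_tokens += pre_len - already_computed, then already_computed = seq_len) counts reused prefix tokens exactly once across chunked prefill. A toy walk-through with assumed numbers (the helper below is not sglang code):

def account_chunk(cached_tokens: int, already_computed: int, pre_len: int, seq_len: int):
    """One prefill chunk: count newly reused prefix tokens exactly once."""
    cached_tokens += pre_len - already_computed
    already_computed = seq_len
    return cached_tokens, already_computed


# A request with a 96-token prompt and a 32-token radix-cache hit,
# prefilled in two chunks (all numbers are assumed for illustration).
cached, computed = 0, 0
cached, computed = account_chunk(cached, computed, pre_len=32, seq_len=80)  # chunk 1
print(cached, computed)  # 32 80 -> 32 tokens came from cache, 48 were computed
cached, computed = account_chunk(cached, computed, pre_len=80, seq_len=96)  # chunk 2
print(cached, computed)  # 32 96 -> the 80-token prefix we computed ourselves is not re-counted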
This is because @@ -1068,12 +1146,15 @@ def merge_batch(self, other: "ScheduleBatch"): self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums self.reqs.extend(other.reqs) - self.return_logprob = self.return_logprob or other.return_logprob - self.has_stream = self.has_stream or other.has_stream - self.has_grammar = self.has_grammar or other.has_grammar + self.return_logprob |= other.return_logprob + self.has_stream |= other.has_stream + self.has_grammar |= other.has_grammar + + if self.spec_info: + self.spec_info.merge_batch(other.spec_info) def get_model_worker_batch(self): - if self.forward_mode.is_decode() or self.forward_mode.is_idle(): + if self.forward_mode.is_decode_or_idle(): extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None else: extend_seq_lens = self.extend_lens @@ -1088,7 +1169,6 @@ def get_model_worker_batch(self): global bid bid += 1 - return ModelWorkerBatch( bid=bid, forward_mode=self.forward_mode, @@ -1097,7 +1177,6 @@ def get_model_worker_batch(self): seq_lens=self.seq_lens, out_cache_loc=self.out_cache_loc, seq_lens_sum=self.seq_lens_sum, - req_to_token_pool_records=self.req_to_token_pool.get_write_records(), return_logprob=self.return_logprob, top_logprobs_nums=self.top_logprobs_nums, global_num_tokens=self.global_num_tokens, @@ -1114,6 +1193,13 @@ def get_model_worker_batch(self): lora_paths=[req.lora_path for req in self.reqs], sampling_info=self.sampling_info, input_embeds=self.input_embeds, + spec_algorithm=self.spec_algorithm, + spec_info=self.spec_info, + capture_hidden_mode=( + getattr(self.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL) + if self.spec_info + else CaptureHiddenMode.NULL + ), ) def copy(self): @@ -1125,6 +1211,8 @@ def copy(self): out_cache_loc=self.out_cache_loc, return_logprob=self.return_logprob, decoding_reqs=self.decoding_reqs, + spec_algorithm=self.spec_algorithm, + enable_custom_logit_processor=self.enable_custom_logit_processor, ) def __str__(self): @@ -1152,9 +1240,6 @@ class ModelWorkerBatch: # The sum of all sequence lengths seq_lens_sum: int - # The memory pool operation records - req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]] - # For logprob return_logprob: bool top_logprobs_nums: Optional[List[int]] @@ -1187,6 +1272,11 @@ class ModelWorkerBatch: # The input Embeds input_embeds: Optional[torch.tensor] = None + # Speculative decoding + spec_algorithm: SpeculativeAlgorithm = None + spec_info: Optional[SpecInfo] = None + capture_hidden_mode: CaptureHiddenMode = None + @triton.jit def write_req_to_token_pool_triton( diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index abe7da9ea21..a3a099b83de 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -18,11 +18,14 @@ from collections import defaultdict from contextlib import contextmanager from enum import Enum, auto -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Set, Union + +import torch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache -from sglang.srt.mem_cache.radix_cache import TreeNode +from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool +from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large. # This can prevent the server from being too conservative. 
@@ -32,83 +35,209 @@ os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096") ) +# Threshold for in-batch prefix cache. +# If a request has a matched prefix length (against existing cache) less than this value, +# the scheduler runs the in-batch prefix caching check for this request. +# If we set it to -1, it means we disable in-batch prefix caching. +IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD = int( + os.environ.get("IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD", "32") +) + +# Threshold for in-batch prefix cache. +# If a request has a matched prefix length (within the waiting queue) larger than this value, +# the scheduler deprioritizes this request +IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int( + os.environ.get("IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD", "32") +) + + +class CacheAwarePolicy(Enum): + """Scheduling policies that are aware of the tree cache.""" + + LPM = "lpm" # longest prefix match + DFS_WEIGHT = "dfs-weight" # depth-first search weighting + + +class CacheAgnosticPolicy(Enum): + """Scheduling policies that are not aware of the tree cache.""" + + FCFS = "fcfs" # first come first serve + LOF = "lof" # longest output first + RANDOM = "random" + class SchedulePolicy: - def __init__(self, policy: str, tree_cache: BasePrefixCache): - if tree_cache.disable and policy in ["lpm", "dfs-weight"]: - # LPM and DFS-weight is meaningless when the tree cache is disabled. - policy = "fcfs" + Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy] - self.policy = policy + def __init__(self, policy: str, tree_cache: BasePrefixCache): + self.policy = self._validate_and_adjust_policy(policy, tree_cache) self.tree_cache = tree_cache - def calc_priority(self, waiting_queue: List[Req]): - if len(waiting_queue) > 128 and self.policy == "lpm": - # Turn off the expensive prefix matching and sorting when the #queue is large. - policy = "fcfs" - else: - policy = self.policy + # It is used to find the matching prefix for in-batch prefix caching. 
+ self.waiting_queue_radix_tree = RadixCache( + req_to_token_pool=None, token_to_kv_pool=None, disable=False + ) - # Compute matched prefix length - prefix_computed = False - if policy == "lpm" or policy == "dfs-weight": - for r in waiting_queue: - # NOTE: the prefix_indices must always be aligned with last_node - r.prefix_indices, r.last_node = self.tree_cache.match_prefix( - rid=r.rid, key=r.adjust_max_prefix_ids() - ) + def calc_priority(self, waiting_queue: List[Req]) -> bool: + policy = self._determine_active_policy(waiting_queue) + prefix_computed = False + if isinstance(policy, CacheAwarePolicy): prefix_computed = True - - if policy == "lpm": - # Longest Prefix Match - waiting_queue.sort(key=lambda x: -len(x.prefix_indices)) - elif policy == "fcfs": - # first come first serve - pass - elif policy == "lof": - # longest output first - waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens) - elif policy == "random": - random.shuffle(waiting_queue) - elif policy == "dfs-weight": - last_node_to_reqs = defaultdict(list) - for req in waiting_queue: - last_node_to_reqs[req.last_node].append(req) - - node_to_weight = defaultdict(int) - for node in last_node_to_reqs: - node_to_weight[node] = len(last_node_to_reqs[node]) - self.calc_weight(self.tree_cache.root_node, node_to_weight) - - waiting_queue.clear() - self.get_dfs_priority( - self.tree_cache.root_node, - node_to_weight, - last_node_to_reqs, - waiting_queue, + temporary_deprioritized = self._compute_prefix_matches( + waiting_queue, policy ) + if policy == CacheAwarePolicy.LPM: + SchedulePolicy._sort_by_longest_prefix( + waiting_queue, temporary_deprioritized + ) + elif policy == CacheAwarePolicy.DFS_WEIGHT: + SchedulePolicy._sort_by_dfs_weight(waiting_queue, self.tree_cache) + else: + raise ValueError(f"Unknown CacheAware Policy: {policy=}") else: - raise ValueError(f"Unknown schedule_policy: {policy=}") + if policy == CacheAgnosticPolicy.FCFS: + pass + elif policy == CacheAgnosticPolicy.LOF: + SchedulePolicy._sort_by_longest_output(waiting_queue) + elif policy == CacheAgnosticPolicy.RANDOM: + SchedulePolicy._sort_randomly(waiting_queue) + else: + raise ValueError(f"Unknown CacheAgnostic Policy: {policy=}") return prefix_computed - def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict): + def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy: + if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM: + # Turn off the expensive prefix matching and sorting when the #queue is large. + return CacheAgnosticPolicy.FCFS + return self.policy + + def _validate_and_adjust_policy( + self, policy: str, tree_cache: BasePrefixCache + ) -> Policy: + """ + Validates the policy and adjusts it if necessary based on tree cache settings. + """ + try: + policy_enum = CacheAwarePolicy(policy) + if tree_cache.disable: + # If tree_cache is disabled, using CacheAgnosticPolicy policy + return CacheAgnosticPolicy.FCFS + return policy_enum + except ValueError: + try: + return CacheAgnosticPolicy(policy) + except ValueError: + raise ValueError(f"Unknown schedule_policy: {policy=}") + + def _compute_prefix_matches( + self, waiting_queue: List[Req], policy: CacheAwarePolicy + ) -> Set[int]: + """ + Computes and caches the matching prefixes for requests in the waiting queue, + and handles in-batch prefix caching logic. 
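The split into CacheAwarePolicy and CacheAgnosticPolicy above makes the fallback rules explicit: a cache-aware name degrades to FCFS when the tree cache is disabled, and unknown names are rejected. A standalone sketch of that resolution order, with the enums re-declared locally:

from enum import Enum
from typing import Union


class CacheAwarePolicy(Enum):
    LPM = "lpm"
    DFS_WEIGHT = "dfs-weight"


class CacheAgnosticPolicy(Enum):
    FCFS = "fcfs"
    LOF = "lof"
    RANDOM = "random"


Policy = Union[CacheAwarePolicy, CacheAgnosticPolicy]


def resolve_policy(name: str, tree_cache_disabled: bool) -> Policy:
    # Cache-aware names degrade to FCFS when there is no tree cache to match against.
    try:
        policy = CacheAwarePolicy(name)
        return CacheAgnosticPolicy.FCFS if tree_cache_disabled else policy
    except ValueError:
        try:
            return CacheAgnosticPolicy(name)
        except ValueError:
            raise ValueError(f"Unknown schedule_policy: {name!r}")


print(resolve_policy("lpm", tree_cache_disabled=False))  # CacheAwarePolicy.LPM
print(resolve_policy("lpm", tree_cache_disabled=True))   # CacheAgnosticPolicy.FCFS
print(resolve_policy("lof", tree_cache_disabled=True))   # CacheAgnosticPolicy.LOF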
+ """ + temporary_deprioritized: Set[int] = set() + self.waiting_queue_radix_tree.reset() + + for r in waiting_queue: + prefix_ids = r.adjust_max_prefix_ids() + + # NOTE: the prefix_indices must always be aligned with last_node + r.prefix_indices, r.last_node = self.tree_cache.match_prefix( + rid=r.rid, key=prefix_ids + ) + + # NOTE(sang): This logic is for in-batch prefix caching; + # If there are more than 1 request that have small matching prefix from + # existing cache, but all those requests share the same prefix, we prefer + # to schedule only one of them so that we can increase the cache hit rate. + # We prefer to set IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD > 0 because too small + # threshold means we cannot use in-batch prefix caching for short prefixes. + # It is kind of common when the engine is long running (e.g., imagine the prefix "the"). + if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD: + in_batch_matching_prefixes, _ = ( + self.waiting_queue_radix_tree.match_prefix( + rid=r.rid, key=prefix_ids + ) + ) + if ( + len(in_batch_matching_prefixes) + >= IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD + ): + temporary_deprioritized.add(r.rid) + else: + # Insert with a dummy key + self.waiting_queue_radix_tree.insert( + prefix_ids, torch.empty(len(prefix_ids), dtype=torch.bool) + ) + return temporary_deprioritized + + @staticmethod + def _sort_by_longest_prefix( + waiting_queue: List[Req], temporary_deprioritized: Set[int] + ) -> None: + """Sorts the waiting queue based on the longest prefix match.""" + waiting_queue.sort( + key=lambda r: ( + -len(r.prefix_indices) + if r.rid not in temporary_deprioritized + else float("inf") + ) + ) + + @staticmethod + def _sort_by_dfs_weight( + waiting_queue: List[Req], tree_cache: BasePrefixCache + ) -> None: + """Sorts the waiting queue based on a depth-first search weighting.""" + last_node_to_reqs = defaultdict(list) + for req in waiting_queue: + last_node_to_reqs[req.last_node].append(req) + + node_to_weight = defaultdict(int) + for node in last_node_to_reqs: + node_to_weight[node] = len(last_node_to_reqs[node]) + SchedulePolicy._calc_weight(tree_cache.root_node, node_to_weight) + + waiting_queue.clear() + SchedulePolicy._get_dfs_priority( + tree_cache.root_node, + node_to_weight, + last_node_to_reqs, + waiting_queue, + ) + + @staticmethod + def _sort_by_longest_output(waiting_queue: List[Req]) -> None: + """Sorts the waiting queue based on the longest output (max_new_tokens).""" + waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens) + + @staticmethod + def _sort_randomly(waiting_queue: List[Req]) -> None: + """Shuffles the waiting queue randomly.""" + random.shuffle(waiting_queue) + + @staticmethod + def _calc_weight(cur_node: TreeNode, node_to_weight: Dict[TreeNode, int]) -> None: for child in cur_node.children.values(): - self.calc_weight(child, node_to_weight) + SchedulePolicy._calc_weight(child, node_to_weight) node_to_weight[cur_node] += node_to_weight[child] - def get_dfs_priority( - self, + @staticmethod + def _get_dfs_priority( cur_node: TreeNode, - node_to_priority: Dict, - last_node_to_reqs: Dict, + node_to_priority: Dict[TreeNode, int], + last_node_to_reqs: Dict[TreeNode, List[Req]], q: List, - ): + ) -> None: childs = [child for child in cur_node.children.values()] childs.sort(key=lambda x: -node_to_priority[x]) for child in childs: - self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q) + SchedulePolicy._get_dfs_priority( + child, node_to_priority, last_node_to_reqs, q + ) 
q.extend(last_node_to_reqs[cur_node]) @@ -122,23 +251,24 @@ class PrefillAdder: def __init__( self, tree_cache: BasePrefixCache, + token_to_kv_pool: BaseTokenToKVPool, running_batch: ScheduleBatch, new_token_ratio: float, - rem_total_tokens: int, rem_input_tokens: int, rem_chunk_tokens: Optional[int], mixed_with_decode_tokens: int = 0, ): self.tree_cache = tree_cache + self.token_to_kv_pool = token_to_kv_pool self.running_batch = running_batch self.new_token_ratio = new_token_ratio - self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens self.rem_chunk_tokens = rem_chunk_tokens if self.rem_chunk_tokens is not None: self.rem_chunk_tokens -= mixed_with_decode_tokens - self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens + self.rem_total_token_offset = mixed_with_decode_tokens + self.cur_rem_token_offset = mixed_with_decode_tokens self.req_states = None self.can_run_list = [] @@ -147,8 +277,7 @@ def __init__( self.log_input_tokens = 0 if running_batch is not None: - # Pre-remove the tokens which will be occupied by the running requests - self.rem_total_tokens -= sum( + self.rem_total_token_offset += sum( [ min( (r.sampling_params.max_new_tokens - len(r.output_ids)), @@ -159,6 +288,22 @@ def __init__( ] ) + @property + def rem_total_tokens(self): + return ( + self.token_to_kv_pool.available_size() + + self.tree_cache.evictable_size() + - self.rem_total_token_offset + ) + + @property + def cur_rem_tokens(self): + return ( + self.token_to_kv_pool.available_size() + + self.tree_cache.evictable_size() + - self.cur_rem_token_offset + ) + def budget_state(self): if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0: return AddReqResult.NO_TOKEN @@ -173,8 +318,8 @@ def budget_state(self): def _prefill_one_req( self, prefix_len: int, extend_input_len: int, max_new_tokens: int ): - self.rem_total_tokens -= extend_input_len + max_new_tokens - self.cur_rem_tokens -= extend_input_len + self.rem_total_token_offset += extend_input_len + max_new_tokens + self.cur_rem_token_offset += extend_input_len self.rem_input_tokens -= extend_input_len if self.rem_chunk_tokens is not None: self.rem_chunk_tokens -= extend_input_len @@ -189,7 +334,7 @@ def add_being_chunked_req(self, req: Req): self.can_run_list.append(req) self._prefill_one_req( - len(req.prefix_indices), + 0, req.extend_input_len, ( min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION) @@ -204,12 +349,10 @@ def add_being_chunked_req(self, req: Req): @contextmanager def _lock_node(self, last_node: TreeNode): try: - delta = self.tree_cache.inc_lock_ref(last_node) - self.rem_total_tokens += delta + self.tree_cache.inc_lock_ref(last_node) yield None finally: - delta = self.tree_cache.dec_lock_ref(last_node) - self.rem_total_tokens += delta + self.tree_cache.dec_lock_ref(last_node) def add_one_req_ignore_eos(self, req: Req): def add_req_state(r, insert_sort=False): @@ -305,7 +448,6 @@ def add_one_req(self, req: Req): or input_tokens <= self.rem_chunk_tokens or ( req.return_logprob - and req.normalized_prompt_logprob is None and req.logprob_start_len != len(req.origin_input_ids) - 1 ) ): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 4ca4cd740dc..79d4db114e8 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -13,6 +13,7 @@ # ============================================================================== """A scheduler that manages a tensor 
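With the change above, the prefill token budget is no longer a stored counter; it is recomputed on demand from the KV pool's available size plus the tree cache's evictable size, minus an offset that grows as requests are admitted. A small sketch of that arithmetic with made-up pool numbers (not sglang's PrefillAdder):

class FakeBudget:
    """Illustrative only: mirrors the rem_total_tokens property idea, not the real class."""

    def __init__(self, pool_available: int, cache_evictable: int, mixed_decode_tokens: int = 0):
        self.pool_available = pool_available    # token_to_kv_pool.available_size()
        self.cache_evictable = cache_evictable  # tree_cache.evictable_size()
        self.rem_total_token_offset = mixed_decode_tokens

    @property
    def rem_total_tokens(self) -> int:
        return self.pool_available + self.cache_evictable - self.rem_total_token_offset

    def admit(self, extend_input_len: int, max_new_tokens: int) -> None:
        # Admitting a prefill reserves its input plus its estimated output.
        self.rem_total_token_offset += extend_input_len + max_new_tokens


budget = FakeBudget(pool_available=6000, cache_evictable=2000)
print(budget.rem_total_tokens)  # 8000
budget.admit(extend_input_len=1024, max_new_tokens=256)
print(budget.rem_total_tokens)  # 6720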
parallel GPU worker.""" +import faulthandler import logging import os import signal @@ -21,16 +22,21 @@ import warnings from collections import deque from concurrent import futures +from dataclasses import dataclass +from http import HTTPStatus from types import SimpleNamespace -from typing import List, Optional +from typing import Dict, List, Optional, Tuple, Union import psutil +import setproctitle import torch import zmq from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.constrained.base_grammar_backend import create_grammar_backend from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer +from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( AbortReq, @@ -45,12 +51,18 @@ OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, + ReleaseMemoryOccupationReqInput, + ReleaseMemoryOccupationReqOutput, + ResumeMemoryOccupationReqInput, + ResumeMemoryOccupationReqOutput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqOutput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, + UpdateWeightsFromTensorReqInput, + UpdateWeightsFromTensorReqOutput, ) from sglang.srt.managers.schedule_batch import ( FINISH_ABORT, @@ -68,11 +80,14 @@ from sglang.srt.managers.session_controller import Session from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient +from sglang.srt.managers.utils import validate_input_length from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( broadcast_pyobj, configure_logger, @@ -83,14 +98,27 @@ set_random_seed, suppress_other_loggers, ) -from sglang.utils import get_exception_traceback +from sglang.utils import TypeBasedDispatcher, get_exception_traceback logger = logging.getLogger(__name__) -# Test retract decode +# Test retract decode for debugging purposes test_retract = get_bool_env_var("SGLANG_TEST_RETRACT") +@dataclass +class GenerationBatchResult: + logits_output: LogitsProcessorOutput + next_token_ids: List[int] + bid: int + + +@dataclass +class EmbeddingBatchResult: + embeddings: torch.Tensor + bid: int + + class Scheduler: """A scheduler that manages a tensor parallel GPU worker.""" @@ -113,27 +141,46 @@ def __init__( self.enable_overlap = not server_args.disable_overlap_schedule self.skip_tokenizer_init = server_args.skip_tokenizer_init self.enable_metrics = server_args.enable_metrics + self.spec_algorithm = SpeculativeAlgorithm.from_string( + server_args.speculative_algorithm + ) + self.decode_mem_cache_buf_multiplier = ( + self.server_args.speculative_num_draft_tokens + if not self.spec_algorithm.is_none() + else 1 + ) + self.enable_hierarchical_cache = server_args.enable_hierarchical_cache + + # Distributed rank info + self.dp_size = server_args.dp_size + self.attn_tp_rank, self.attn_tp_size, self.dp_rank = ( + compute_dp_attention_world_info( + 
server_args.enable_dp_attention, + self.tp_rank, + self.tp_size, + self.dp_size, + ) + ) # Init inter-process communication context = zmq.Context(2) - - if self.tp_rank == 0 or self.server_args.enable_dp_attention: + if self.attn_tp_rank == 0: self.recv_from_tokenizer = get_zmq_socket( - context, zmq.PULL, port_args.scheduler_input_ipc_name + context, zmq.PULL, port_args.scheduler_input_ipc_name, False ) self.send_to_tokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_ipc_name + context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) if server_args.skip_tokenizer_init: - # Directly send to the tokenizer/api + # Directly send to the TokenizerManager self.send_to_detokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_ipc_name + context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) else: - # Send to the detokenizer + # Send to the DetokenizerManager self.send_to_detokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.detokenizer_ipc_name + context, zmq.PUSH, port_args.detokenizer_ipc_name, False ) else: self.recv_from_tokenizer = None @@ -161,6 +208,7 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) self.tokenizer = self.processor.tokenizer else: @@ -168,6 +216,7 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) # Check whether overlap can be enabled @@ -196,6 +245,21 @@ def __init__( nccl_port=port_args.nccl_port, ) + # Launch a worker for speculative decoding if needed + if self.spec_algorithm.is_eagle(): + from sglang.srt.speculative.eagle_worker import EAGLEWorker + + self.draft_worker = EAGLEWorker( + gpu_id=gpu_id, + tp_rank=tp_rank, + server_args=server_args, + nccl_port=port_args.nccl_port, + target_worker=self.tp_worker, + dp_rank=dp_rank, + ) + else: + self.draft_worker = None + # Get token and memory info from the model worker ( self.max_total_num_tokens, @@ -211,13 +275,14 @@ def __init__( _, ) = self.tp_worker.get_worker_info() self.tp_cpu_group = self.tp_worker.get_tp_cpu_group() + self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group() self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func() global_server_args_dict.update(worker_global_server_args_dict) set_random_seed(self.random_seed) - # Print debug info logger.info( f"max_total_num_tokens={self.max_total_num_tokens}, " + f"chunked_prefill_size={server_args.chunked_prefill_size}, " f"max_prefill_tokens={self.max_prefill_tokens}, " f"max_running_requests={self.max_running_requests}, " f"context_len={self.model_config.context_len}" @@ -254,12 +319,16 @@ def __init__( self.forward_ct = 0 self.forward_ct_decode = 0 self.num_generated_tokens = 0 + self.spec_num_total_accepted_tokens = 0 + self.spec_num_total_forward_ct = 0 self.last_decode_stats_tic = time.time() self.stream_interval = server_args.stream_interval self.current_stream = torch.get_device_module(self.device).current_stream() + if self.device == "cpu": + self.current_stream.synchronize = lambda: None # No-op for CPU # Session info - self.sessions = {} + self.sessions: Dict[str, Session] = {} # Init chunked prefill self.chunked_prefill_size = server_args.chunked_prefill_size @@ -273,28 +342,9 @@ def __init__( # Init the grammar backend for constrained generation self.grammar_queue: List[Req] = [] if not server_args.skip_tokenizer_init: - if 
server_args.grammar_backend == "outlines": - from sglang.srt.constrained.outlines_backend import ( - OutlinesGrammarBackend, - ) - - self.grammar_backend = OutlinesGrammarBackend( - self.tokenizer, - whitespace_pattern=server_args.constrained_json_whitespace_pattern, - allow_jump_forward=not server_args.disable_jump_forward, - ) - elif server_args.grammar_backend == "xgrammar": - from sglang.srt.constrained.xgrammar_backend import ( - XGrammarGrammarBackend, - ) - - self.grammar_backend = XGrammarGrammarBackend( - self.tokenizer, vocab_size=self.model_config.vocab_size - ) - else: - raise ValueError( - f"Invalid grammar backend: {server_args.grammar_backend}" - ) + self.grammar_backend = create_grammar_backend( + server_args, self.tokenizer, self.model_config.vocab_size + ) else: self.grammar_backend = None @@ -329,6 +379,10 @@ def __init__( t.start() self.parent_process = psutil.Process().parent() + self.memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.enable_memory_saver + ) + # Init profiler if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "": self.profiler = None @@ -356,22 +410,58 @@ def __init__( }, ) + # The largest prefill length of a single request + self._largest_prefill_len: int = 0 + # The largest context length (prefill + generation) of a single request + self._largest_prefill_decode_len: int = 0 + + # Init request dispatcher + self._request_dispatcher = TypeBasedDispatcher( + [ + (TokenizedGenerateReqInput, self.handle_generate_request), + (TokenizedEmbeddingReqInput, self.handle_embedding_request), + (FlushCacheReq, self.flush_cache_wrapped), + (AbortReq, self.abort_request), + (UpdateWeightFromDiskReqInput, self.update_weights_from_disk), + (InitWeightsUpdateGroupReqInput, self.init_weights_update_group), + ( + UpdateWeightsFromDistributedReqInput, + self.update_weights_from_distributed, + ), + (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), + (GetWeightsByNameReqInput, self.get_weights_by_name), + (ProfileReq, self.profile), + (OpenSessionReqInput, self.open_session), + (CloseSessionReqInput, self.close_session), + ( + ReleaseMemoryOccupationReqInput, + lambda _: self.release_memory_occupation(), + ), + ( + ResumeMemoryOccupationReqInput, + lambda _: self.resume_memory_occupation(), + ), + ] + ) + def watchdog_thread(self): """A watch dog thread that will try to kill the server itself if one batch takes too long.""" self.watchdog_last_forward_ct = 0 self.watchdog_last_time = time.time() while True: + current = time.time() if self.cur_batch is not None: if self.watchdog_last_forward_ct == self.forward_ct: - if time.time() > self.watchdog_last_time + self.watchdog_timeout: + if current > self.watchdog_last_time + self.watchdog_timeout: logger.error(f"Watchdog timeout ({self.watchdog_timeout=})") break else: self.watchdog_last_forward_ct = self.forward_ct - self.watchdog_last_time = time.time() - time.sleep(self.watchdog_timeout / 2) - + self.watchdog_last_time = current + time.sleep(self.watchdog_timeout // 2) + # Wait sometimes so that the parent process can print the error. 
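The dispatcher table above replaces the long isinstance chain in process_input_requests. The real TypeBasedDispatcher comes from sglang.utils and is not shown in this patch, so the class below is only a plausible sketch of the idea: look up a handler by the request's type, call it, and return its output so the caller can decide whether to send a reply.

from typing import Any, Callable, List, Tuple


class TypeBasedDispatcherSketch:
    """Assumed shape of a type-to-handler dispatcher; not the sglang implementation."""

    def __init__(self, mapping: List[Tuple[type, Callable[[Any], Any]]]):
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        for ty, handler in self._mapping:
            if isinstance(obj, ty):
                return handler(obj)
        raise ValueError(f"Invalid request: {obj}")


class PingReq:   # stand-in for request types such as TokenizedGenerateReqInput
    pass


class FlushReq:  # stand-in for control types such as FlushCacheReq
    pass


dispatcher = TypeBasedDispatcherSketch(
    [
        (PingReq, lambda _: "pong"),  # handlers that return a value get sent back
        (FlushReq, lambda _: None),   # handlers that return None send nothing
    ]
)
print(dispatcher(PingReq()))   # pong
print(dispatcher(FlushReq()))  # None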
+ time.sleep(5) self.parent_process.send_signal(signal.SIGQUIT) @torch.no_grad() @@ -382,16 +472,13 @@ def event_loop_normal(self): self.process_input_requests(recv_reqs) batch = self.get_next_batch_to_run() - if self.server_args.enable_dp_attention: - batch = self.prepare_dp_attn_batch(batch) - self.cur_batch = batch if batch: result = self.run_batch(batch) self.process_batch_result(batch, result) else: - # Self-check and re-init some states when the server is idle + # When the server is idle, so self-check and re-init some states self.check_memory() self.new_token_ratio = self.init_new_token_ratio @@ -400,7 +487,7 @@ def event_loop_normal(self): @torch.no_grad() def event_loop_overlap(self): """A scheduler loop that overlaps the CPU processing and GPU computation.""" - result_queue = deque() + self.result_queue = deque() while True: recv_reqs = self.recv_requests() @@ -408,12 +495,13 @@ def event_loop_overlap(self): batch = self.get_next_batch_to_run() self.cur_batch = batch + if batch: result = self.run_batch(batch) - result_queue.append((batch.copy(), result)) + self.result_queue.append((batch.copy(), result)) if self.last_batch is None: - # A dummy first batch to start the pipeline for overlap scheduler. + # Create a dummy first batch to start the pipeline for overlap schedule. # It is now used for triggering the sampling_info_done event. tmp_batch = ScheduleBatch( reqs=None, @@ -423,20 +511,22 @@ def event_loop_overlap(self): self.process_batch_result(tmp_batch, None) if self.last_batch: - tmp_batch, tmp_result = result_queue.popleft() + # Process the results of the last batch + tmp_batch, tmp_result = self.result_queue.popleft() tmp_batch.next_batch_sampling_info = ( self.tp_worker.cur_sampling_info if batch else None ) self.process_batch_result(tmp_batch, tmp_result) elif batch is None: - # Self-check and re-init some states when the server is idle + # When the server is idle, so self-check and re-init some states self.check_memory() self.new_token_ratio = self.init_new_token_ratio self.last_batch = batch - def recv_requests(self): - if self.tp_rank == 0 or self.server_args.enable_dp_attention: + def recv_requests(self) -> List[Req]: + """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.""" + if self.attn_tp_rank == 0: recv_reqs = [] while True: @@ -448,60 +538,59 @@ def recv_requests(self): else: recv_reqs = None - if self.tp_size != 1 and not self.server_args.enable_dp_attention: + if self.server_args.enable_dp_attention: + if self.attn_tp_rank == 0: + work_reqs = [ + req + for req in recv_reqs + if isinstance( + req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) + ) + ] + control_reqs = [ + req + for req in recv_reqs + if not isinstance( + req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) + ) + ] + else: + work_reqs = None + control_reqs = None + + if self.attn_tp_size != 1: + attn_tp_rank_0 = self.dp_rank * self.attn_tp_size + work_reqs = broadcast_pyobj( + work_reqs, + self.attn_tp_rank, + self.attn_tp_cpu_group, + src=attn_tp_rank_0, + ) + if self.tp_size != 1: + control_reqs = broadcast_pyobj( + control_reqs, self.tp_rank, self.tp_cpu_group + ) + recv_reqs = work_reqs + control_reqs + elif self.tp_size != 1: recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group) return recv_reqs def process_input_requests(self, recv_reqs: List): for recv_req in recv_reqs: - if isinstance(recv_req, TokenizedGenerateReqInput): - self.handle_generate_request(recv_req) - elif isinstance(recv_req, TokenizedEmbeddingReqInput): - 
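The overlap loop above launches the current batch, stashes (batch, result) on result_queue, and only processes the previous iteration's result, so CPU post-processing overlaps the next GPU step. A toy simulation of that one-step delay (plain Python, no GPU; the dummy first batch and sampling_info_done handling are omitted):

from collections import deque


def run_batch(batch_id: int) -> str:
    return f"result-{batch_id}"  # stands in for the GPU forward pass

result_queue = deque()
last_batch = None
processed = []

for batch in [0, 1, 2, None]:  # None marks an idle step that drains the queue
    if batch is not None:
        result_queue.append((batch, run_batch(batch)))
    if last_batch is not None:
        # Results are consumed one iteration after they were produced.
        processed.append(result_queue.popleft())
    last_batch = batch

print(processed)  # [(0, 'result-0'), (1, 'result-1'), (2, 'result-2')]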
self.handle_embedding_request(recv_req) - elif isinstance(recv_req, FlushCacheReq): - self.flush_cache() - elif isinstance(recv_req, AbortReq): - self.abort_request(recv_req) - elif isinstance(recv_req, UpdateWeightFromDiskReqInput): - success, message = self.update_weights_from_disk(recv_req) - self.send_to_tokenizer.send_pyobj( - UpdateWeightFromDiskReqOutput(success, message) - ) - elif isinstance(recv_req, GetWeightsByNameReqInput): - parameter = self.get_weights_by_name(recv_req) - self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter)) - elif isinstance(recv_req, InitWeightsUpdateGroupReqInput): - success, message = self.init_weights_update_group(recv_req) - self.send_to_tokenizer.send_pyobj( - InitWeightsUpdateGroupReqOutput(success, message) - ) - elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput): - success, message = self.update_weights_from_distributed(recv_req) - self.send_to_tokenizer.send_pyobj( - UpdateWeightsFromDistributedReqOutput(success, message) - ) - elif isinstance(recv_req, GetWeightsByNameReqInput): - parameter = self.get_weights_by_name(recv_req) - self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter)) - elif isinstance(recv_req, ProfileReq): - if recv_req == ProfileReq.START_PROFILE: - self.start_profile() - else: - self.stop_profile() - elif isinstance(recv_req, OpenSessionReqInput): - session_id = self.open_session(recv_req) - self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id)) - elif isinstance(recv_req, CloseSessionReqInput): - self.close_session(recv_req) - else: - raise ValueError(f"Invalid request: {recv_req}") + output = self._request_dispatcher(recv_req) + if output is not None: + self.send_to_tokenizer.send_pyobj(output) def handle_generate_request( self, recv_req: TokenizedGenerateReqInput, ): # Create a new request - if recv_req.session_id is None or recv_req.session_id not in self.sessions: + if ( + recv_req.session_params is None + or recv_req.session_params.id is None + or recv_req.session_params.id not in self.sessions + ): if recv_req.input_embeds is not None: # Generate fake input_ids based on the length of input_embeds @@ -509,31 +598,52 @@ def handle_generate_request( fake_input_ids = [1] * seq_length recv_req.input_ids = fake_input_ids + # Handle custom logit processor passed to the request + custom_logit_processor = recv_req.custom_logit_processor + if ( + not self.server_args.enable_custom_logit_processor + and custom_logit_processor is not None + ): + logger.warning( + "The SGLang server is not configured to enable custom logit processor." + "The custom logit processor passed in will be ignored." + "Please set --enable-custom-logits-processor to enable this feature." 
+ ) + custom_logit_processor = None + req = Req( recv_req.rid, recv_req.input_text, recv_req.input_ids, recv_req.sampling_params, + return_logprob=recv_req.return_logprob, + top_logprobs_num=recv_req.top_logprobs_num, + stream=recv_req.stream, lora_path=recv_req.lora_path, input_embeds=recv_req.input_embeds, + custom_logit_processor=custom_logit_processor, + eos_token_ids=self.model_config.hf_eos_token_id, ) req.tokenizer = self.tokenizer - if recv_req.session_id is not None: + if ( + recv_req.session_params is not None + and recv_req.session_params.id is not None + ): req.finished_reason = FINISH_ABORT( - f"Invalid request: session id {recv_req.session_id} does not exist" + f"Invalid request: session id {recv_req.session_params.id} does not exist" ) self.waiting_queue.append(req) return else: - # Create a new request from a previsou session - session = self.sessions[recv_req.session_id] + # Create a new request from a previous session + session = self.sessions[recv_req.session_params.id] req = session.create_req(recv_req, self.tokenizer) if isinstance(req.finished_reason, FINISH_ABORT): self.waiting_queue.append(req) return - # Handle image inputs + # Handle multimodal inputs if recv_req.image_inputs is not None: image_inputs = ImageInputs.from_dict(recv_req.image_inputs) # Expand a single image token into multiple dummy tokens for receiving image embeddings @@ -543,36 +653,36 @@ def handle_generate_request( req.extend_image_inputs(image_inputs) if len(req.origin_input_ids) >= self.max_req_input_len: - logger.error( + error_msg = ( "Multimodal prompt is too long after expanding multimodal tokens. " - f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. " + f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}." ) + logger.error(error_msg) req.origin_input_ids = [0] req.image_inputs = None req.sampling_params.max_new_tokens = 0 req.finished_reason = FINISH_ABORT( - "Multimodal prompt is too long. Check server logs for details." + error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError" ) self.waiting_queue.append(req) return - # Copy more attributes - req.return_logprob = recv_req.return_logprob - req.top_logprobs_num = recv_req.top_logprobs_num - req.stream = recv_req.stream - req.logprob_start_len = recv_req.logprob_start_len + # Validate prompts length + error_msg = validate_input_length( + req, + self.max_req_input_len, + self.server_args.allow_auto_truncate, + ) + if error_msg: + self.waiting_queue.append(req) + return - if req.logprob_start_len == -1: + # Copy more attributes + if recv_req.logprob_start_len == -1: # By default, only return the logprobs for output tokens - req.logprob_start_len = len(recv_req.input_ids) - 1 - - # Truncate prompts that are too long - if len(req.origin_input_ids) > self.max_req_input_len: - logger.warning( - "Request length is longer than the KV cache pool size or " - "the max context length. Truncated!!!" 
- ) - req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len] + req.logprob_start_len = len(req.origin_input_ids) - 1 + else: + req.logprob_start_len = recv_req.logprob_start_len req.sampling_params.max_new_tokens = min( ( @@ -588,12 +698,15 @@ def handle_generate_request( if ( req.sampling_params.json_schema is not None or req.sampling_params.regex is not None + or req.sampling_params.ebnf is not None ): assert self.grammar_backend is not None if req.sampling_params.json_schema is not None: key = ("json", req.sampling_params.json_schema) elif req.sampling_params.regex is not None: key = ("regex", req.sampling_params.regex) + elif req.sampling_params.ebnf is not None: + key = ("ebnf", req.sampling_params.ebnf) req.grammar = self.grammar_backend.get_cached_value(key) if not req.grammar: @@ -617,27 +730,34 @@ def handle_embedding_request( ) req.tokenizer = self.tokenizer - # Truncate prompts that are too long - if len(req.origin_input_ids) >= self.max_req_input_len: - logger.warning( - "Request length is longer than the KV cache pool size or " - "the max context length. Truncated!!!" - ) - req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len] + # Validate prompts length + error_msg = validate_input_length( + req, + self.max_req_input_len, + self.server_args.allow_auto_truncate, + ) + if error_msg: + self.waiting_queue.append(req) + return + # Copy more attributes + req.logprob_start_len = len(req.origin_input_ids) - 1 self.waiting_queue.append(req) - def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked): - if isinstance(self.tree_cache, RadixCache): - self.tree_cache_metrics["total"] += ( - adder.log_input_tokens + adder.log_hit_tokens - ) / 10**9 - self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 - tree_cache_hit_rate = ( - self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] - ) - else: - tree_cache_hit_rate = 0.0 + def log_prefill_stats( + self, + adder: PrefillAdder, + can_run_list: List[Req], + running_bs: ScheduleBatch, + has_being_chunked: bool, + ): + self.tree_cache_metrics["total"] += ( + adder.log_input_tokens + adder.log_hit_tokens + ) / 10**9 + self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9 + tree_cache_hit_rate = ( + self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"] + ) num_used = self.max_total_num_tokens - ( self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() @@ -672,31 +792,56 @@ def log_decode_stats(self): self.num_generated_tokens = 0 self.last_decode_stats_tic = time.time() num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0 - logger.info( - f"Decode batch. " - f"#running-req: {num_running_reqs}, " - f"#token: {num_used}, " - f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"gen throughput (token/s): {gen_throughput:.2f}, " - f"#queue-req: {len(self.waiting_queue)}" - ) + if self.spec_algorithm.is_none(): + msg = ( + f"Decode batch. " + f"#running-req: {num_running_reqs}, " + f"#token: {num_used}, " + f"token usage: {num_used / self.max_total_num_tokens:.2f}, " + f"gen throughput (token/s): {gen_throughput:.2f}, " + f"#queue-req: {len(self.waiting_queue)}" + ) + spec_accept_length = 0 + else: + spec_accept_length = ( + self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct + ) + self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0 + msg = ( + f"Decode batch. 
" + f"#running-req: {num_running_reqs}, " + f"#token: {num_used}, " + f"token usage: {num_used / self.max_total_num_tokens:.2f}, " + f"accept len: {spec_accept_length:.2f}, " + f"gen throughput (token/s): {gen_throughput:.2f}, " + f"#queue-req: {len(self.waiting_queue)}" + ) + + logger.info(msg) if self.enable_metrics: self.stats.num_running_reqs = num_running_reqs self.stats.num_used_tokens = num_used self.stats.token_usage = num_used / self.max_total_num_tokens self.stats.gen_throughput = gen_throughput self.stats.num_queue_reqs = len(self.waiting_queue) + self.stats.spec_accept_length = spec_accept_length self.metrics_collector.log_stats(self.stats) def check_memory(self): available_size = ( self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() ) - if available_size != self.max_total_num_tokens: + protected_size = self.tree_cache.protected_size() + memory_leak = available_size != ( + self.max_total_num_tokens + if not self.enable_hierarchical_cache + else self.max_total_num_tokens - protected_size + ) + if memory_leak: msg = ( "KV cache pool leak detected!" - f"{available_size=}, {self.max_total_num_tokens=}\n" + f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n" ) warnings.warn(msg) if crash_on_warnings(): @@ -712,7 +857,7 @@ def check_memory(self): if crash_on_warnings(): raise ValueError(msg) - def get_next_batch_to_run(self): + def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: # Merge the prefill batch into the running batch if self.last_batch and self.last_batch.forward_mode.is_extend(): if self.being_chunked_req: @@ -729,16 +874,23 @@ def get_next_batch_to_run(self): else: self.running_batch.merge_batch(self.last_batch) - # Run prefill first if possible new_batch = self.get_new_batch_prefill() if new_batch is not None: - return new_batch + # Run prefill first if possible + ret = new_batch + else: + # Run decode + if self.running_batch is None: + ret = None + else: + self.running_batch = self.update_running_batch(self.running_batch) + ret = self.running_batch - # Run decode - if self.running_batch is None: - return None - self.running_batch = self.update_running_batch(self.running_batch) - return self.running_batch + # Handle DP attention + if self.server_args.enable_dp_attention: + ret = self.prepare_dp_attn_batch(ret) + + return ret def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: # Check if the grammar is ready in the grammar queue @@ -762,9 +914,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: # Prefill policy adder = PrefillAdder( self.tree_cache, + self.token_to_kv_pool, self.running_batch, self.new_token_ratio, - self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(), self.max_prefill_tokens, self.chunked_prefill_size, running_bs if self.is_mixed_chunk else 0, @@ -804,7 +956,16 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: res = adder.add_one_req(req) if res != AddReqResult.CONTINUE: if res == AddReqResult.NO_TOKEN: - self.batch_is_full = True + if self.enable_hierarchical_cache: + # Set batch_is_full after making sure there are requests that can be served + self.batch_is_full = len(adder.can_run_list) > 0 or ( + self.running_batch is not None + and not self.running_batch.is_empty() + ) + else: + self.batch_is_full = True + break + if self.server_args.prefill_only_one_req: break # Update waiting queue @@ -823,7 +984,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.being_chunked_req.is_being_chunked += 1 # Print stats - if self.tp_rank == 0: + 
if self.attn_tp_rank == 0: self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked) # Create a new batch @@ -834,6 +995,8 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.tree_cache, self.model_config, self.enable_overlap, + self.spec_algorithm, + self.server_args.enable_custom_logit_processor, ) new_batch.prepare_for_extend() @@ -867,11 +1030,15 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: return None # Check if decode out of memory - if not batch.check_decode_mem() or (test_retract and batch.batch_size() > 10): + if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or ( + test_retract and batch.batch_size() > 10 + ): old_ratio = self.new_token_ratio retracted_reqs, new_token_ratio = batch.retract_decode() self.new_token_ratio = new_token_ratio + if self.draft_worker: + self.draft_worker.finish_request(retracted_reqs) logger.info( "Decode out of memory happened. " @@ -886,7 +1053,7 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: ) # Check for jump-forward - if not self.disable_jump_forward: + if not self.disable_jump_forward and batch.has_grammar: jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func) self.waiting_queue.extend(jump_forward_reqs) if batch.is_empty(): @@ -900,72 +1067,94 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: batch.prepare_for_decode() return batch - def run_batch(self, batch: ScheduleBatch): + def run_batch( + self, batch: ScheduleBatch + ) -> Union[GenerationBatchResult, EmbeddingBatchResult]: """Run a batch.""" self.forward_ct += 1 if self.is_generation: - model_worker_batch = batch.get_model_worker_batch() - if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0: + if self.spec_algorithm.is_none(): + model_worker_batch = batch.get_model_worker_batch() logits_output, next_token_ids = self.tp_worker.forward_batch_generation( model_worker_batch ) - elif batch.forward_mode.is_idle(): - model_worker_batch = batch.get_model_worker_batch() - self.tp_worker.forward_batch_idle(model_worker_batch) - return else: - logits_output = None - if self.skip_tokenizer_init: - next_token_ids = torch.full( - (batch.batch_size(),), self.tokenizer.eos_token_id - ) - else: - next_token_ids = torch.full((batch.batch_size(),), 0) + ( + logits_output, + next_token_ids, + model_worker_batch, + num_accepted_tokens, + ) = self.draft_worker.forward_batch_speculative_generation(batch) + self.spec_num_total_accepted_tokens += ( + num_accepted_tokens + batch.batch_size() + ) + self.spec_num_total_forward_ct += batch.batch_size() + self.num_generated_tokens += num_accepted_tokens batch.output_ids = next_token_ids - ret = logits_output, next_token_ids, model_worker_batch.bid + + ret = GenerationBatchResult( + logits_output=logits_output, + next_token_ids=next_token_ids, + bid=model_worker_batch.bid, + ) else: # embedding or reward model - assert batch.extend_num_tokens != 0 model_worker_batch = batch.get_model_worker_batch() embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch) - ret = embeddings, model_worker_batch.bid + ret = EmbeddingBatchResult( + embeddings=embeddings, bid=model_worker_batch.bid + ) return ret - def process_batch_result(self, batch: ScheduleBatch, result): + def process_batch_result( + self, + batch: ScheduleBatch, + result: Union[GenerationBatchResult, EmbeddingBatchResult], + ): if batch.forward_mode.is_decode(): self.process_batch_result_decode(batch, result) if 
batch.is_empty(): self.running_batch = None elif batch.forward_mode.is_extend(): self.process_batch_result_prefill(batch, result) + elif batch.forward_mode.is_idle(): + if self.enable_overlap: + self.tp_worker.resolve_batch_result(result.bid) elif batch.forward_mode.is_dummy_first(): batch.next_batch_sampling_info.update_regex_vocab_mask() self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() - def process_batch_result_prefill(self, batch: ScheduleBatch, result): + def process_batch_result_prefill( + self, + batch: ScheduleBatch, + result: Union[GenerationBatchResult, EmbeddingBatchResult], + ): + skip_stream_req = None if self.is_generation: - logits_output, next_token_ids, bid = result + ( + logits_output, + next_token_ids, + bid, + ) = ( + result.logits_output, + result.next_token_ids, + result.bid, + ) if self.enable_overlap: logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid) else: # Move next_token_ids and logprobs to cpu + next_token_ids = next_token_ids.tolist() if batch.return_logprob: logits_output.next_token_logprobs = ( - logits_output.next_token_logprobs[ - torch.arange(len(next_token_ids), device=self.device), - next_token_ids, - ].tolist() + logits_output.next_token_logprobs.tolist() ) logits_output.input_token_logprobs = ( logits_output.input_token_logprobs.tolist() ) - logits_output.normalized_prompt_logprobs = ( - logits_output.normalized_prompt_logprobs.tolist() - ) - next_token_ids = next_token_ids.tolist() # Check finish conditions logprob_pt = 0 @@ -980,7 +1169,6 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): continue if req.is_being_chunked <= 0: - req.completion_tokens_wo_jump_forward += 1 req.output_ids.append(next_token_id) req.check_finished() @@ -1000,6 +1188,10 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): else: # being chunked reqs' prefill is not finished req.is_being_chunked -= 1 + # There is only at most one request being currently chunked. + # Because this request does not finish prefill, + # we don't want to stream the request currently being chunked. 
+ skip_stream_req = req if batch.next_batch_sampling_info: batch.next_batch_sampling_info.update_regex_vocab_mask() @@ -1007,7 +1199,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): batch.next_batch_sampling_info.sampling_info_done.set() else: # embedding or reward model - embeddings, bid = result + embeddings, bid = result.embeddings, result.bid embeddings = embeddings.tolist() # Check finish conditions @@ -1029,23 +1221,27 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): # being chunked reqs' prefill is not finished req.is_being_chunked -= 1 - self.stream_output(batch.reqs) + self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) - def process_batch_result_decode(self, batch: ScheduleBatch, result): - logits_output, next_token_ids, bid = result + def process_batch_result_decode( + self, + batch: ScheduleBatch, + result: GenerationBatchResult, + ): + logits_output, next_token_ids, bid = ( + result.logits_output, + result.next_token_ids, + result.bid, + ) self.num_generated_tokens += len(batch.reqs) if self.enable_overlap: logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid) next_token_logprobs = logits_output.next_token_logprobs else: - # Move next_token_ids and logprobs to cpu - if batch.return_logprob: - next_token_logprobs = logits_output.next_token_logprobs[ - torch.arange(len(next_token_ids), device=self.device), - next_token_ids, - ].tolist() next_token_ids = next_token_ids.tolist() + if batch.return_logprob: + next_token_logprobs = logits_output.next_token_logprobs.tolist() self.token_to_kv_pool.free_group_begin() @@ -1059,19 +1255,25 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1]) continue - req.completion_tokens_wo_jump_forward += 1 - req.output_ids.append(next_token_id) + if batch.spec_algorithm.is_none(): + # speculative worker will solve the output_ids in speculative decoding + req.output_ids.append(next_token_id) + req.check_finished() if req.finished(): self.tree_cache.cache_finished_req(req) if req.return_logprob: - req.output_token_logprobs.append( - (next_token_logprobs[i], next_token_id) - ) + req.output_token_logprobs_val.append(next_token_logprobs[i]) + req.output_token_logprobs_idx.append(next_token_id) if req.top_logprobs_num > 0: - req.output_top_logprobs.append(logits_output.output_top_logprobs[i]) + req.output_top_logprobs_val.append( + logits_output.next_token_top_logprobs_val[i] + ) + req.output_top_logprobs_idx.append( + logits_output.next_token_top_logprobs_idx[i] + ) if req.grammar is not None: req.grammar.accept_token(next_token_id) @@ -1082,13 +1284,13 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() - self.stream_output(batch.reqs) + self.stream_output(batch.reqs, batch.return_logprob) self.token_to_kv_pool.free_group_end() self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30) if ( - self.tp_rank == 0 + self.attn_tp_rank == 0 and self.forward_ct_decode % self.server_args.decode_log_interval == 0 ): self.log_decode_stats() @@ -1102,180 +1304,206 @@ def add_logprob_return_values( output: LogitsProcessorOutput, ): """Attach logprobs to the return values.""" - req.output_token_logprobs.append( - (output.next_token_logprobs[i], next_token_ids[i]) - ) + req.output_token_logprobs_val.append(output.next_token_logprobs[i]) + 
req.output_token_logprobs_idx.append(next_token_ids[i]) # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len - if req.normalized_prompt_logprob is None: - req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] - - if req.input_token_logprobs is None: - input_token_logprobs = output.input_token_logprobs[ + if req.input_token_logprobs_val is None: + input_token_logprobs_val = output.input_token_logprobs[ pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens ] - input_token_ids = req.fill_ids[ + + input_token_logprobs_idx = req.fill_ids[ len(req.fill_ids) - num_input_logprobs + 1 : len(req.fill_ids) - req.last_update_decode_tokens ] - # Clip the padded hash values from image tokens. # Otherwise, it will lead to detokenization errors. - input_token_ids = [ + input_token_logprobs_idx = [ x if x < self.model_config.vocab_size - 1 else 0 - for x in input_token_ids + for x in input_token_logprobs_idx ] - req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids)) - if ( req.logprob_start_len == 0 ): # The first token does not have logprob, pad it. - req.input_token_logprobs = [ - (None, req.fill_ids[0]) - ] + req.input_token_logprobs + input_token_logprobs_val = [None] + input_token_logprobs_val + input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx + + req.input_token_logprobs_val = input_token_logprobs_val + req.input_token_logprobs_idx = input_token_logprobs_idx if req.last_update_decode_tokens != 0: # Some decode tokens are re-computed in an extend batch - req.output_token_logprobs.extend( - list( - zip( - output.input_token_logprobs[ - pt - + num_input_logprobs - - 1 - - req.last_update_decode_tokens : pt - + num_input_logprobs - - 1 - ], - req.fill_ids[ - len(req.fill_ids) - - req.last_update_decode_tokens : len(req.fill_ids) - ], - ) - ) + req.output_token_logprobs_val.extend( + output.input_token_logprobs[ + pt + + num_input_logprobs + - 1 + - req.last_update_decode_tokens : pt + + num_input_logprobs + - 1 + ], + ) + req.output_token_logprobs_idx.extend( + req.fill_ids[ + len(req.fill_ids) + - req.last_update_decode_tokens : len(req.fill_ids) + ] ) if req.top_logprobs_num > 0: - if req.input_top_logprobs is None: - req.input_top_logprobs = output.input_top_logprobs[i] + if req.input_top_logprobs_val is None: + req.input_top_logprobs_val = output.input_top_logprobs_val[i] + req.input_top_logprobs_idx = output.input_top_logprobs_idx[i] if req.logprob_start_len == 0: - req.input_top_logprobs = [None] + req.input_top_logprobs + req.input_top_logprobs_val = [None] + req.input_top_logprobs_val + req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx if req.last_update_decode_tokens != 0: - req.output_top_logprobs.extend( - output.input_top_logprobs[i][-req.last_update_decode_tokens :] + req.output_top_logprobs_val.extend( + output.input_top_logprobs_val[i][-req.last_update_decode_tokens :] + ) + req.output_top_logprobs_idx.extend( + output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :] ) - req.output_top_logprobs.append(output.output_top_logprobs[i]) + + req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i]) + req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i]) return num_input_logprobs - def stream_output(self, reqs: List[Req]): + def stream_output( + self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None + ): """Stream the output to 
detokenizer.""" - output_rids = [] - output_meta_info: List[dict] = [] - output_finished_reason: List[BaseFinishReason] = [] + rids = [] + finished_reasons: List[BaseFinishReason] = [] + if self.is_generation: - output_vids = [] + vids = [] decoded_texts = [] - output_read_ids = [] - output_read_offsets = [] + decode_ids_list = [] + read_offsets = [] output_ids = [] - output_skip_special_tokens = [] - output_spaces_between_special_tokens = [] - output_no_stop_trim = [] - else: # embedding or reward model - output_embeddings = [] - is_stream_iter = self.forward_ct_decode % self.stream_interval == 0 + skip_special_tokens = [] + spaces_between_special_tokens = [] + no_stop_trim = [] + prompt_tokens = [] + completion_tokens = [] + cached_tokens = [] + spec_verify_ct = [] + + if return_logprob: + input_token_logprobs_val = [] + input_token_logprobs_idx = [] + output_token_logprobs_val = [] + output_token_logprobs_idx = [] + input_top_logprobs_val = [] + input_top_logprobs_idx = [] + output_top_logprobs_val = [] + output_top_logprobs_idx = [] + else: + input_token_logprobs_val = input_token_logprobs_idx = ( + output_token_logprobs_val + ) = output_token_logprobs_idx = input_top_logprobs_val = ( + input_top_logprobs_idx + ) = output_top_logprobs_val = output_top_logprobs_idx = None + + for req in reqs: + if req is skip_req: + continue - for req in reqs: - # TODO(lianmin): revisit this for overlap + retract + stream - if req.finished() or ( - req.stream and (is_stream_iter or len(req.output_ids) == 1) - ): - output_rids.append(req.rid) - output_finished_reason.append(req.finished_reason) - if self.is_generation: - output_vids.append(req.vid) + # TODO(lianmin): revisit this for overlap + retract + stream + if ( + req.finished() + # If stream, follow the given stream_interval + or (req.stream and len(req.output_ids) % self.stream_interval == 0) + # If not stream, we still want to output some tokens to get the benefit of incremental decoding. 
+ or (not req.stream and len(req.output_ids) % 50 == 0) + ): + if self.draft_worker and req.finished(): + self.draft_worker.finish_request(req) + + rids.append(req.rid) + finished_reasons.append( + req.finished_reason.to_json() if req.finished_reason else None + ) + vids.append(req.vid) decoded_texts.append(req.decoded_text) - read_ids, read_offset = req.init_incremental_detokenize() - output_read_ids.append(read_ids) - output_read_offsets.append(read_offset) + decode_ids, read_offset = req.init_incremental_detokenize() + decode_ids_list.append(decode_ids) + read_offsets.append(read_offset) if self.skip_tokenizer_init: output_ids.append(req.output_ids) - output_skip_special_tokens.append( - req.sampling_params.skip_special_tokens - ) - output_spaces_between_special_tokens.append( + skip_special_tokens.append(req.sampling_params.skip_special_tokens) + spaces_between_special_tokens.append( req.sampling_params.spaces_between_special_tokens ) - output_no_stop_trim.append(req.sampling_params.no_stop_trim) - - meta_info = { - "prompt_tokens": len(req.origin_input_ids), - "completion_tokens": len(req.output_ids), - "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward, - "cached_tokens": req.cached_tokens, - "finish_reason": ( - req.finished_reason.to_json() - if req.finished_reason is not None - else None - ), - } - if req.return_logprob: - ( - meta_info["input_token_logprobs"], - meta_info["output_token_logprobs"], - meta_info["input_top_logprobs"], - meta_info["output_top_logprobs"], - meta_info["normalized_prompt_logprob"], - ) = ( - req.input_token_logprobs, - req.output_token_logprobs, - req.input_top_logprobs, - req.output_top_logprobs, - req.normalized_prompt_logprob, - ) - output_meta_info.append(meta_info) - else: # embedding or reward model - output_embeddings.append(req.embedding) - meta_info = { - "prompt_tokens": len(req.origin_input_ids), - } - output_meta_info.append(meta_info) - - # Send to detokenizer - if output_rids: - if self.is_generation: + no_stop_trim.append(req.sampling_params.no_stop_trim) + + prompt_tokens.append(len(req.origin_input_ids)) + completion_tokens.append(len(req.output_ids)) + cached_tokens.append(req.cached_tokens) + + if not self.spec_algorithm.is_none(): + spec_verify_ct.append(req.spec_verify_ct) + + if return_logprob: + input_token_logprobs_val.append(req.input_token_logprobs_val) + input_token_logprobs_idx.append(req.input_token_logprobs_idx) + output_token_logprobs_val.append(req.output_token_logprobs_val) + output_token_logprobs_idx.append(req.output_token_logprobs_idx) + input_top_logprobs_val.append(req.input_top_logprobs_val) + input_top_logprobs_idx.append(req.input_top_logprobs_idx) + output_top_logprobs_val.append(req.output_top_logprobs_val) + output_top_logprobs_idx.append(req.output_top_logprobs_idx) + + # Send to detokenizer + if rids: self.send_to_detokenizer.send_pyobj( BatchTokenIDOut( - output_rids, - output_vids, + rids, + finished_reasons, + vids, decoded_texts, - output_read_ids, - output_read_offsets, + decode_ids_list, + read_offsets, output_ids, - output_skip_special_tokens, - output_spaces_between_special_tokens, - output_meta_info, - output_finished_reason, - output_no_stop_trim, - ) - ) - else: # embedding or reward model - self.send_to_detokenizer.send_pyobj( - BatchEmbeddingOut( - output_rids, - output_embeddings, - output_meta_info, - output_finished_reason, + skip_special_tokens, + spaces_between_special_tokens, + no_stop_trim, + prompt_tokens, + completion_tokens, + cached_tokens, + spec_verify_ct, + 
input_token_logprobs_val, + input_token_logprobs_idx, + output_token_logprobs_val, + output_token_logprobs_idx, + input_top_logprobs_val, + input_top_logprobs_idx, + output_top_logprobs_val, + output_top_logprobs_idx, ) ) + else: # embedding or reward model + embeddings = [] + prompt_tokens = [] + for req in reqs: + if req.finished(): + rids.append(req.rid) + finished_reasons.append(req.finished_reason.to_json()) + embeddings.append(req.embedding) + prompt_tokens.append(len(req.origin_input_ids)) + self.send_to_detokenizer.send_pyobj( + BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens) + ) def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): # Check if other DP workers have running batches @@ -1303,12 +1531,7 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): # Check forward mode for cuda graph if not self.server_args.disable_cuda_graph: forward_mode_state = torch.tensor( - ( - 1 - if local_batch.forward_mode.is_decode() - or local_batch.forward_mode.is_idle() - else 0 - ), + (1 if local_batch.forward_mode.is_decode_or_idle() else 0), dtype=torch.int32, ) torch.distributed.all_reduce( @@ -1328,6 +1551,8 @@ def get_idle_batch(self): self.tree_cache, self.model_config, self.enable_overlap, + self.spec_algorithm, + self.server_args.enable_custom_logit_processor, ) idle_batch.prepare_for_idle() return idle_batch @@ -1356,6 +1581,9 @@ def move_ready_grammar_requests(self): self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs]) self.grammar_queue = self.grammar_queue[num_ready_reqs:] + def flush_cache_wrapped(self, recv_req: FlushCacheReq): + self.flush_cache() + def flush_cache(self): """Flush the memory pool and cache.""" if len(self.waiting_queue) == 0 and ( @@ -1367,6 +1595,15 @@ def flush_cache(self): self.grammar_backend.reset() self.req_to_token_pool.clear() self.token_to_kv_pool.clear() + + if not self.spec_algorithm.is_none(): + self.draft_worker.model_runner.req_to_token_pool.clear() + self.draft_worker.model_runner.token_to_kv_pool.clear() + + self.num_generated_tokens = 0 + self.forward_ct_decode = 0 + self.spec_num_total_accepted_tokens = 0 + self.spec_num_total_forward_ct = 0 torch.cuda.empty_cache() logger.info("Cache flushed successfully!") if_success = True @@ -1408,16 +1645,17 @@ def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): assert flash_cache_success, "Cache flush failed after updating weights" else: logger.error(message) - return success, message + return UpdateWeightFromDiskReqOutput(success, message) def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput): """Initialize the online model parameter update group.""" success, message = self.tp_worker.init_weights_update_group(recv_req) - return success, message + return InitWeightsUpdateGroupReqOutput(success, message) def update_weights_from_distributed( - self, recv_req: UpdateWeightsFromDistributedReqInput - ): + self, + recv_req: UpdateWeightsFromDistributedReqInput, + ) -> Tuple[bool, str]: """Update the online model parameter.""" success, message = self.tp_worker.update_weights_from_distributed(recv_req) if success: @@ -1425,11 +1663,44 @@ def update_weights_from_distributed( assert flash_cache_success, "Cache flush failed after updating weights" else: logger.error(message) - return success, message + return UpdateWeightsFromDistributedReqOutput(success, message) + + def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): + """Update the online model parameter from tensors.""" + success, message = 
self.tp_worker.update_weights_from_tensor(recv_req) + # TODO extract common code b/t update_weights_from_distributed and update_weights_from_tensor later + if success: + flash_cache_success = self.flush_cache() + assert flash_cache_success, "Cache flush failed after updating weights" + else: + logger.error(message) + return UpdateWeightsFromTensorReqOutput(success, message) def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): parameter = self.tp_worker.get_weights_by_name(recv_req) - return parameter + return GetWeightsByNameReqOutput(parameter) + + def release_memory_occupation(self): + self.stashed_model_static_state = _export_static_state( + self.tp_worker.worker.model_runner.model + ) + self.memory_saver_adapter.pause() + self.flush_cache() + return ReleaseMemoryOccupationReqOutput() + + def resume_memory_occupation(self): + self.memory_saver_adapter.resume() + _import_static_state( + self.tp_worker.worker.model_runner.model, self.stashed_model_static_state + ) + del self.stashed_model_static_state + return ResumeMemoryOccupationReqOutput() + + def profile(self, recv_req: ProfileReq): + if recv_req == ProfileReq.START_PROFILE: + self.start_profile() + else: + self.stop_profile() def start_profile(self) -> None: if self.profiler is None: @@ -1445,16 +1716,20 @@ def stop_profile(self) -> None: ) logger.info("Profiler is done") - def open_session(self, recv_req: OpenSessionReqInput) -> str: + def open_session(self, recv_req: OpenSessionReqInput): # handle error session_id = recv_req.session_id if session_id in self.sessions: logger.warning(f"session id {session_id} already exist, cannot open.") + return OpenSessionReqOutput(session_id, False) + elif session_id is None: + logger.warning(f"session id is None, cannot open.") + return OpenSessionReqOutput(session_id, False) else: self.sessions[session_id] = Session( recv_req.capacity_of_str_len, session_id ) - return session_id + return OpenSessionReqOutput(session_id, True) def close_session(self, recv_req: CloseSessionReqInput): # handle error @@ -1465,6 +1740,20 @@ def close_session(self, recv_req: CloseSessionReqInput): del self.sessions[session_id] +def _export_static_state(model): + return dict( + buffers=[ + (name, buffer.detach().clone()) for name, buffer in model.named_buffers() + ] + ) + + +def _import_static_state(model, static_params): + self_named_buffers = dict(model.named_buffers()) + for name, tensor in static_params["buffers"]: + self_named_buffers[name][...] 
= tensor + + def run_scheduler_process( server_args: ServerArgs, port_args: PortArgs, @@ -1473,26 +1762,35 @@ def run_scheduler_process( dp_rank: Optional[int], pipe_writer, ): + setproctitle.setproctitle("sglang::scheduler") + faulthandler.enable() + # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var if dp_rank is None and "SGLANG_DP_RANK" in os.environ: dp_rank = int(os.environ["SGLANG_DP_RANK"]) + # Configue the logger if dp_rank is None: configure_logger(server_args, prefix=f" TP{tp_rank}") else: configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}") + suppress_other_loggers() - # set cpu affinity to this gpu process + # Set cpu affinity to this gpu process if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"): set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id) - suppress_other_loggers() parent_process = psutil.Process().parent() + # Create a scheduler and run the event loop try: scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank) pipe_writer.send( - {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens} + { + "status": "ready", + "max_total_num_tokens": scheduler.max_total_num_tokens, + "max_req_input_len": scheduler.max_req_input_len, + } ) if scheduler.enable_overlap: scheduler.event_loop_overlap() diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py index dc5a1b670ea..4f4af636757 100644 --- a/python/sglang/srt/managers/session_controller.py +++ b/python/sglang/srt/managers/session_controller.py @@ -10,41 +10,116 @@ # limitations under the License. # ============================================================================== +import logging import uuid +from typing import Dict, Optional from sglang.srt.managers.io_struct import TokenizedGenerateReqInput -from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req +from sglang.srt.managers.schedule_batch import Req + + +class SessionReqNode: + def __init__(self, req, parent=None, childs=None): + self.req = req + self.parent = parent + if parent is not None: + parent.childs.append(self) + self.childs = [] if not childs else childs + + def clear_childs(self, req_dict): + for req_node in self.childs: + req_node.clear(req_dict) + self.childs = [] + + def clear(self, req_dict): + for req_node in self.childs: + req_node.clear(req_dict) + + if self.req.finished_reason == None: + self.req.to_abort = True + del req_dict[self.req.rid] + + def abort(self): + if self.req.finished_reason == None: + self.req.to_abort = True + + def __str__(self): + return self._str_helper(self.req.rid) + + def _str_helper(self, prefix=""): + if len(self.childs) == 0: + return prefix + "\n" + else: + origin_prefix = prefix + prefix += " -- " + self.childs[0].req.rid + ret = self.childs[0]._str_helper(prefix) + for child in self.childs[1:]: + prefix = " " * len(origin_prefix) + " \- " + child.req.rid + ret += child._str_helper(prefix) + return ret class Session: - def __init__(self, capacity_of_str_len: int, session_id: str = None): + def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None): self.session_id = session_id if session_id is not None else uuid.uuid4().hex self.capacity_of_str_len = capacity_of_str_len - self.reqs: List[Req] = [] + self.req_nodes: Dict[str, SessionReqNode] = {} def create_req(self, req: TokenizedGenerateReqInput, tokenizer): - if req.session_rid is not None: - while len(self.reqs) > 0: - if self.reqs[-1].rid == req.session_rid: - break - 
self.reqs = self.reqs[:-1] + assert req.session_params is not None + session_params = req.session_params + + last_req_node = None + last_req = None + abort = False + if session_params.replace: + if session_params.rid is None: + for _, req_node in self.req_nodes.items(): + req_node.clear(self.req_nodes) + else: + if session_params.rid not in self.req_nodes: + abort = True + else: + last_req_node = self.req_nodes[session_params.rid] + last_req_node.abort() + last_req = last_req_node.req + last_req_node.clear_childs(self.req_nodes) else: - self.reqs = [] - if len(self.reqs) > 0: + if session_params.rid is not None: + if session_params.rid not in self.req_nodes: + abort = True + else: + last_req_node = self.req_nodes[session_params.rid] + last_req = last_req_node.req + if not last_req.finished(): + logging.warning( + "The request in a session is appending to a request that hasn't finished." + ) + abort = True + + if last_req is not None: + # trim bos token if it is an append + if tokenizer is not None and req.input_ids[0] == tokenizer.bos_token_id: + req.input_ids = req.input_ids[1:] + input_ids = ( - self.reqs[-1].origin_input_ids - + self.reqs[-1].output_ids[ - : self.reqs[-1].sampling_params.max_new_tokens - ] - + req.input_ids + last_req.origin_input_ids + + last_req.output_ids[: last_req.sampling_params.max_new_tokens] ) + if session_params.offset and session_params.offset != 0: + input_ids = input_ids[: session_params.offset] + req.input_ids + else: + input_ids += req.input_ids input_ids_unpadded = ( - self.reqs[-1].origin_input_ids_unpadded - + self.reqs[-1].output_ids[ - : self.reqs[-1].sampling_params.max_new_tokens - ] - + req.input_ids + last_req.origin_input_ids_unpadded + + last_req.output_ids[: last_req.sampling_params.max_new_tokens] ) + if session_params.offset and session_params.offset != 0: + input_ids_unpadded = ( + input_ids_unpadded[: session_params.offset] + req.input_ids + ) + else: + input_ids_unpadded += req.input_ids else: input_ids = req.input_ids input_ids_unpadded = req.input_ids @@ -56,14 +131,15 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer): sampling_params=req.sampling_params, lora_path=req.lora_path, session_id=self.session_id, + custom_logit_processor=req.custom_logit_processor, ) - if len(self.reqs) > 0: - new_req.image_inputs = self.reqs[-1].image_inputs + if last_req is not None: + new_req.image_inputs = last_req.image_inputs new_req.tokenizer = tokenizer - if req.session_rid is not None and len(self.reqs) == 0: - new_req.finished_reason = FINISH_ABORT( - f"Invalid request: requested session rid {req.session_rid} does not exist in the session history" - ) + if abort: + new_req.to_abort = True else: - self.reqs.append(new_req) + new_req_node = SessionReqNode(new_req, last_req_node) + self.req_nodes[req.rid] = new_req_node + return new_req diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 56e01528add..53e1f4edae0 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -18,11 +18,15 @@ import dataclasses import logging import os +import pickle import signal import sys +import threading import time import uuid -from typing import Dict, List, Optional, Tuple, Union +from datetime import datetime +from http import HTTPStatus +from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union import fastapi import uvloop @@ -30,6 +34,7 @@ import zmq.asyncio from fastapi import BackgroundTasks 
+from sglang.srt.aio_rwlock import RWLock from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.managers.image_processor import ( @@ -42,6 +47,7 @@ BatchStrOut, BatchTokenIDOut, CloseSessionReqInput, + ConfigureLoggingReq, EmbeddingReqInput, FlushCacheReq, GenerateReqInput, @@ -52,17 +58,29 @@ OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, + ReleaseMemoryOccupationReqInput, + ReleaseMemoryOccupationReqOutput, + ResumeMemoryOccupationReqInput, + ResumeMemoryOccupationReqOutput, + SessionParams, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqOutput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, + UpdateWeightsFromTensorReqInput, + UpdateWeightsFromTensorReqOutput, ) from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import get_zmq_socket, kill_process_tree +from sglang.srt.utils import ( + dataclass_to_string_truncated, + get_zmq_socket, + kill_process_tree, +) +from sglang.utils import TypeBasedDispatcher, get_exception_traceback asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -76,11 +94,15 @@ class ReqState: out_list: List finished: bool event: asyncio.Event + obj: Any # For metrics created_time: float first_token_time: Optional[float] = None + # For streaming output + last_output_offset: int = 0 + class TokenizerManager: """TokenizerManager is a process that tokenizes the text.""" @@ -91,16 +113,19 @@ def __init__( port_args: PortArgs, ): # Parse args + self.server_args = server_args self.enable_metrics = server_args.enable_metrics + self.log_requests = server_args.log_requests + self.log_requests_level = 0 # Init inter-process communication context = zmq.asyncio.Context(2) self.recv_from_detokenizer = get_zmq_socket( - context, zmq.PULL, port_args.tokenizer_ipc_name + context, zmq.PULL, port_args.tokenizer_ipc_name, True ) self.send_to_scheduler = get_zmq_socket( - context, zmq.PUSH, port_args.scheduler_input_ipc_name + context, zmq.PUSH, port_args.scheduler_input_ipc_name, True ) # Read model args @@ -119,6 +144,7 @@ def __init__( self.is_generation = self.model_config.is_generation self.context_len = self.model_config.context_len + self.image_token_id = self.model_config.image_token_id # Create image processor placeholder self.image_processor = get_dummy_image_processor() @@ -132,6 +158,7 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) self.tokenizer = self.processor.tokenizer os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -145,21 +172,48 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) # Store states - self.to_create_loop = True + self.no_create_loop = False self.rid_to_state: Dict[str, ReqState] = {} - - # For update model weights - self.model_update_lock = asyncio.Lock() - self.model_update_result = None + self.dump_requests_folder = "" # By default do not dump + self.dump_requests_threshold = 1000 + self.dump_request_list: List[Tuple] = [] + + # The event to notify the weight sync is finished. 
+ self.model_update_lock = RWLock() + self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = ( + None + ) + self.asyncio_tasks = set() # For session info self.session_futures = {} # session_id -> asyncio event # Others self.gracefully_exit = False + self.init_weights_update_group_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.update_weights_from_distributed_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.update_weights_from_tensor_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.get_weights_by_name_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.release_memory_occupation_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.resume_memory_occupation_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + # Set after scheduler is initialized + self.max_req_input_len = None # Metrics if self.enable_metrics: @@ -170,6 +224,44 @@ def __init__( }, ) + self._result_dispatcher = TypeBasedDispatcher( + [ + ( + (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut), + self._handle_batch_output, + ), + (OpenSessionReqOutput, self._handle_open_session_req_output), + ( + UpdateWeightFromDiskReqOutput, + self._handle_update_weights_from_disk_req_output, + ), + ( + InitWeightsUpdateGroupReqOutput, + self.init_weights_update_group_communicator.handle_recv, + ), + ( + UpdateWeightsFromDistributedReqOutput, + self.update_weights_from_distributed_communicator.handle_recv, + ), + ( + UpdateWeightsFromTensorReqOutput, + self.update_weights_from_tensor_communicator.handle_recv, + ), + ( + GetWeightsByNameReqOutput, + self.get_weights_by_name_communicator.handle_recv, + ), + ( + ReleaseMemoryOccupationReqOutput, + self.release_memory_occupation_communicator.handle_recv, + ), + ( + ResumeMemoryOccupationReqOutput, + self.resume_memory_occupation_communicator.handle_recv, + ), + ] + ) + async def generate_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], @@ -177,11 +269,7 @@ async def generate_request( ): created_time = time.time() - if self.to_create_loop: - self.create_handle_loop() - - while self.model_update_lock.locked(): - await asyncio.sleep(0.001) + self.auto_create_handle_loop() if isinstance(obj, EmbeddingReqInput) and self.is_generation: raise ValueError( @@ -190,17 +278,25 @@ async def generate_request( ) obj.normalize_batch_and_arguments() - is_single = obj.is_single - if is_single: - tokenized_obj = await self._tokenize_one_request(obj) - self.send_to_scheduler.send_pyobj(tokenized_obj) - async for response in self._wait_one_response(obj, request, created_time): - yield response - else: - async for response in self._handle_batch_request( - obj, request, created_time - ): - yield response + + if self.log_requests: + max_length = 2048 if self.log_requests_level == 0 else 1 << 30 + logger.info( + f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}" + ) + + async with self.model_update_lock.reader_lock: + is_single = obj.is_single + if is_single: + tokenized_obj = await self._tokenize_one_request(obj) + self._send_one_request(obj, tokenized_obj, created_time) + async for response in self._wait_one_response(obj, request): + yield response + else: + async for response in self._handle_batch_request( + obj, request, created_time + ): + yield response async def _tokenize_one_request( self, @@ -214,35 +310,58 @@ async def _tokenize_one_request( if not 
self.server_args.disable_radix_cache: raise ValueError( "input_embeds is provided while disable_radix_cache is False. " - "Please add `--disable-radix-cach` when you launch the server " + "Please add `--disable-radix-cache` when you launch the server " "if you want to use input_embeds as inputs." ) input_embeds = obj.input_embeds input_ids = obj.input_ids - elif obj.input_ids is None: - input_ids = self.tokenizer.encode(input_text) - else: + elif obj.input_ids is not None: input_ids = obj.input_ids + else: + if self.tokenizer is None: + raise ValueError( + "The engine initialized with skip_tokenizer_init=True cannot " + "accept text prompts. Please provide input_ids or re-initialize " + "the engine with skip_tokenizer_init=False." + ) + input_ids = self.tokenizer.encode(input_text) if self.is_generation: # TODO: also support getting embeddings for multimodal models image_inputs: Dict = await self.image_processor.process_images_async( - obj.image_data, input_text or input_ids, obj + obj.image_data, input_text or input_ids, obj, self.max_req_input_len ) if image_inputs and "input_ids" in image_inputs: input_ids = image_inputs["input_ids"] return_logprob = obj.return_logprob logprob_start_len = obj.logprob_start_len top_logprobs_num = obj.top_logprobs_num - session_id = obj.session[0] if obj.session else None - session_rid = obj.session[1] if obj.session else None + session_params = ( + SessionParams(**obj.session_params) if obj.session_params else None + ) - if obj.input_ids is not None and len(input_ids) >= self.context_len: + input_token_num = len(input_ids) if input_ids is not None else 0 + if input_token_num >= self.context_len: raise ValueError( - f"The input ({len(input_ids)} tokens) is longer than the " + f"The input ({input_token_num} tokens) is longer than the " f"model's context length ({self.context_len} tokens)." ) + if ( + obj.sampling_params.get("max_new_tokens") is not None + and obj.sampling_params.get("max_new_tokens") + input_token_num + >= self.context_len + ): + raise ValueError( + f"Requested token count exceeds the model's maximum context length " + f"of {self.context_len} tokens. You requested a total of " + f"{obj.sampling_params.get('max_new_tokens') + input_token_num} " + f"tokens: {input_token_num} tokens from the input messages and " + f"{obj.sampling_params.get('max_new_tokens')} tokens for the " + f"completion. Please reduce the number of tokens in the input " + f"messages or the completion to fit within the limit." 
+ ) + # Parse sampling parameters sampling_params = SamplingParams(**obj.sampling_params) sampling_params.normalize(self.tokenizer) @@ -262,8 +381,8 @@ async def _tokenize_one_request( obj.stream, lora_path=obj.lora_path, input_embeds=input_embeds, - session_id=session_id, - session_rid=session_rid, + session_params=session_params, + custom_logit_processor=obj.custom_logit_processor, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( @@ -275,16 +394,24 @@ async def _tokenize_one_request( return tokenized_obj - async def _wait_one_response( + def _send_one_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], - request: Optional[fastapi.Request] = None, + tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], created_time: Optional[float] = None, ): - """Wait for the response of one request.""" event = asyncio.Event() - state = ReqState([], False, event, created_time=created_time) + state = ReqState([], False, event, obj, created_time=created_time) self.rid_to_state[obj.rid] = state + self.send_to_scheduler.send_pyobj(tokenized_obj) + + async def _wait_one_response( + self, + obj: Union[GenerateReqInput, EmbeddingReqInput], + request: Optional[fastapi.Request] = None, + ): + """Wait for the response of one request.""" + state = self.rid_to_state[obj.rid] while True: try: @@ -295,27 +422,36 @@ async def _wait_one_response( raise ValueError(f"Abort request {obj.rid}") continue - if isinstance(obj, GenerateReqInput): - out = self.convert_logprob_style( - state.out_list[-1], - obj.return_logprob, - obj.top_logprobs_num, - obj.return_text_in_logprobs, - ) - else: # isinstance(obj, (EmbeddingReqInput,)) - out = state.out_list[-1] + out = state.out_list[-1] state.out_list = [] if state.finished: - if self.server_args.log_requests: - # Log requests - logger.info(f"in={obj}, out={out}") + if self.log_requests: + max_length = 2048 if self.log_requests_level == 0 else 1 << 30 + msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}" + logger.info(msg) del self.rid_to_state[obj.rid] + + # Check if this was an abort/error created by scheduler + if isinstance(out["meta_info"].get("finish_reason"), dict): + finish_reason = out["meta_info"]["finish_reason"] + if ( + finish_reason.get("type") == "abort" + and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST + ): + raise ValueError(finish_reason["message"]) + yield out break state.event.clear() - yield out + + if obj.stream: + yield out + else: + if request is not None and await request.is_disconnected(): + self.abort_request(obj.rid) + raise ValueError(f"Abort request {obj.rid}") async def _handle_batch_request( self, @@ -332,10 +468,8 @@ async def _handle_batch_request( for i in range(batch_size): tmp_obj = obj[i] tokenized_obj = await self._tokenize_one_request(tmp_obj) - self.send_to_scheduler.send_pyobj(tokenized_obj) - generators.append( - self._wait_one_response(tmp_obj, request, created_time) - ) + self._send_one_request(tmp_obj, tokenized_obj, created_time) + generators.append(self._wait_one_response(tmp_obj, request)) rids.append(tmp_obj.rid) else: # FIXME: When using batch and parallel_sample_num together, the perf is not optimal. 
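The hunk above splits request handling into _send_one_request, which registers a ReqState keyed by rid and forwards the tokenized request to the scheduler, and _wait_one_response, which only awaits that state's event. The snippet below is a condensed, self-contained sketch of that pattern, not the repository code: the scheduler socket is replaced by a plain asyncio.Queue, and the polling timeout and client-disconnect handling are omitted.

import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class MiniReqState:
    out_list: List[Any] = field(default_factory=list)
    finished: bool = False
    event: asyncio.Event = field(default_factory=asyncio.Event)


class MiniTokenizerManager:
    def __init__(self):
        self.rid_to_state: Dict[str, MiniReqState] = {}
        self.to_scheduler: asyncio.Queue = asyncio.Queue()

    def send_one_request(self, rid: str, tokenized_obj: Any) -> None:
        # Register the per-request state before sending, so a fast reply
        # can never arrive for an unknown rid.
        self.rid_to_state[rid] = MiniReqState()
        self.to_scheduler.put_nowait(tokenized_obj)

    def handle_batch_output(self, rid: str, out: Any, finished: bool) -> None:
        # Called when the detokenizer reports a result for this rid.
        state = self.rid_to_state.get(rid)
        if state is None:
            return
        state.out_list.append(out)
        state.finished = finished
        state.event.set()

    async def wait_one_response(self, rid: str):
        # Yield each chunk as it arrives; drop the state once finished.
        state = self.rid_to_state[rid]
        while True:
            await state.event.wait()
            out = state.out_list[-1]
            state.out_list = []
            if state.finished:
                del self.rid_to_state[rid]
                yield out
                break
            state.event.clear()
            yield out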
@@ -360,10 +494,8 @@ async def _handle_batch_request( tokenized_obj.sampling_params = copy.copy(tokenized_obj.sampling_params) tokenized_obj.sampling_params.max_new_tokens = 0 tokenized_obj.stream = False - self.send_to_scheduler.send_pyobj(tokenized_obj) - await self._wait_one_response( - tmp_obj, request, created_time - ).__anext__() + self._send_one_request(tmp_obj, tokenized_obj, created_time) + await self._wait_one_response(tmp_obj, request).__anext__() # Expand requests, assign new rids for them, and send them for i in range(batch_size): @@ -371,10 +503,8 @@ async def _handle_batch_request( tmp_obj = copy.copy(objs[i]) tokenized_obj = copy.copy(tokenized_objs[i]) tokenized_obj.rid = tmp_obj.regenerate_rid() - self.send_to_scheduler.send_pyobj(tokenized_obj) - generators.append( - self._wait_one_response(tmp_obj, request, created_time) - ) + self._send_one_request(tmp_obj, tokenized_obj, created_time) + generators.append(self._wait_one_response(tmp_obj, request)) rids.append(tmp_obj.rid) # Wait for all requests @@ -424,127 +554,150 @@ async def update_weights_from_disk( self, obj: UpdateWeightFromDiskReqInput, request: Optional[fastapi.Request] = None, - ): - if self.to_create_loop: - self.create_handle_loop() + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() # default the load format to the server_args if obj.load_format is None: obj.load_format = self.server_args.load_format + logger.info("Start update_weights. Load format=%s", obj.load_format) - if not self.model_update_lock.locked(): - - async with self.model_update_lock: - # wait for the previous generation requests to finish - for i in range(3): - while len(self.rid_to_state) > 0: - await asyncio.sleep(0.001) - # FIXME: We add some sleep here to avoid some race conditions. - # We can use a read-write lock as a better fix. - await asyncio.sleep(0.01) - self.send_to_scheduler.send_pyobj(obj) - self.model_update_result = asyncio.Future() - - if self.server_args.dp_size == 1: - result = await self.model_update_result - if result.success: - self.server_args.model_path = obj.model_path - self.server_args.load_format = obj.load_format - self.model_path = obj.model_path - return result.success, result.message - else: # self.server_args.dp_size > 1 - self.model_update_tmp = [] - result = await self.model_update_result - - all_success = all([r.success for r in result]) - if all_success is True: - self.server_args.model_path = obj.model_path - self.server_args.load_format = obj.load_format - self.model_path = obj.model_path - all_message = [r.message for r in result] - all_message = " | ".join(all_message) - return all_success, all_message + if True: + # Hold the lock if it is not async. This means that weight sync + # cannot run while requests are in progress. + async with self.model_update_lock.writer_lock: + return await self._wait_for_model_update_from_disk(obj) - else: - return False, "Another update is in progress. Please try again later." 
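update_weights_from_disk now serializes weight syncs against in-flight generation through the reader/writer lock imported from sglang.srt.aio_rwlock: generate_request runs under model_update_lock.reader_lock, while weight updates take writer_lock. The class below is only a minimal illustration of that pattern, not the RWLock shipped in the repository, and it exposes the locks as methods rather than attributes.

import asyncio
from contextlib import asynccontextmanager


class SketchRWLock:
    """Many concurrent readers, one exclusive writer."""

    def __init__(self):
        self._readers = 0
        self._readers_done = asyncio.Event()
        self._readers_done.set()
        self._writer_gate = asyncio.Lock()

    @asynccontextmanager
    async def reader_lock(self):
        # New readers queue behind an active writer.
        async with self._writer_gate:
            self._readers += 1
            self._readers_done.clear()
        try:
            yield
        finally:
            self._readers -= 1
            if self._readers == 0:
                self._readers_done.set()

    @asynccontextmanager
    async def writer_lock(self):
        async with self._writer_gate:
            # Hold the gate to block new readers, then drain existing ones.
            await self._readers_done.wait()
            yield


async def generate(lock: SketchRWLock):
    async with lock.reader_lock():
        ...  # serve a request


async def update_weights(lock: SketchRWLock):
    async with lock.writer_lock():
        ...  # swap weights while no request is in flight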
+ async def _wait_for_model_update_from_disk( + self, obj: UpdateWeightFromDiskReqInput + ) -> Tuple[bool, str]: + self.send_to_scheduler.send_pyobj(obj) + self.model_update_result = asyncio.Future() + if self.server_args.dp_size == 1: + result = await self.model_update_result + if result.success: + self.served_model_name = obj.model_path + self.server_args.model_path = obj.model_path + self.server_args.load_format = obj.load_format + self.model_path = obj.model_path + return result.success, result.message + else: # self.server_args.dp_size > 1 + self.model_update_tmp = [] + result = await self.model_update_result + + all_success = all([r.success for r in result]) + if all_success is True: + self.server_args.model_path = obj.model_path + self.server_args.load_format = obj.load_format + self.model_path = obj.model_path + all_message = [r.message for r in result] + all_message = " | ".join(all_message) + return all_success, all_message async def init_weights_update_group( self, obj: InitWeightsUpdateGroupReqInput, request: Optional[fastapi.Request] = None, - ) -> bool: - if self.to_create_loop: - self.create_handle_loop() - self.send_to_scheduler.send_pyobj(obj) - - self.init_weights_update_group_result = asyncio.Future() + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() assert ( self.server_args.dp_size == 1 ), "dp_size must be 1 for init parameter update group" - result = await self.init_weights_update_group_result + result = (await self.init_weights_update_group_communicator(obj))[0] return result.success, result.message async def update_weights_from_distributed( self, obj: UpdateWeightsFromDistributedReqInput, request: Optional[fastapi.Request] = None, - ): - if self.to_create_loop: - self.create_handle_loop() - - if not self.model_update_lock.locked(): - async with self.model_update_lock: - self.send_to_scheduler.send_pyobj(obj) - self.parameter_update_result = asyncio.Future() - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be for update weights from distributed" - result = await self.parameter_update_result - return result.success, result.message - else: - logger.error("Another parameter update is in progress in tokenizer manager") - return ( - False, - "Another parameter update is in progress. Please try again later.", - ) + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be for update weights from distributed" + + # This means that weight sync + # cannot run while requests are in progress. + async with self.model_update_lock.writer_lock: + result = (await self.update_weights_from_distributed_communicator(obj))[0] + return result.success, result.message + + async def update_weights_from_tensor( + self, + obj: UpdateWeightsFromTensorReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + self.auto_create_handle_loop() + assert ( + self.server_args.dp_size == 1 + ), "dp_size must be for update weights from distributed" + + # This means that weight sync + # cannot run while requests are in progress. 
+ async with self.model_update_lock.writer_lock: + result = (await self.update_weights_from_tensor_communicator(obj))[0] + return result.success, result.message async def get_weights_by_name( self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None ): - if self.to_create_loop: - self.create_handle_loop() - - self.send_to_scheduler.send_pyobj(obj) - self.get_weights_by_name_result = asyncio.Future() + self.auto_create_handle_loop() + results = await self.get_weights_by_name_communicator(obj) + all_parameters = [r.parameter for r in results] if self.server_args.dp_size == 1: - result = await self.get_weights_by_name_result - return result.parameter + return all_parameters[0] else: - self.get_weights_by_name_tmp = [] - result = await self.get_weights_by_name_result - all_parameters = [r.parameter for r in result] return all_parameters + async def release_memory_occupation( + self, + obj: ReleaseMemoryOccupationReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.release_memory_occupation_communicator(obj) + + async def resume_memory_occupation( + self, + obj: ResumeMemoryOccupationReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.resume_memory_occupation_communicator(obj) + async def open_session( self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None ): - if self.to_create_loop: - self.create_handle_loop() + self.auto_create_handle_loop() + + if obj.session_id is None: + obj.session_id = uuid.uuid4().hex + elif obj.session_id in self.session_futures: + return None - session_id = uuid.uuid4().hex - obj.session_id = session_id self.send_to_scheduler.send_pyobj(obj) - self.session_futures[session_id] = asyncio.Future() - session_id = await self.session_futures[session_id] - del self.session_futures[session_id] + + self.session_futures[obj.session_id] = asyncio.Future() + session_id = await self.session_futures[obj.session_id] + del self.session_futures[obj.session_id] return session_id async def close_session( self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None ): - assert not self.to_create_loop, "close session should not be the first request" await self.send_to_scheduler.send_pyobj(obj) + def configure_logging(self, obj: ConfigureLoggingReq): + if obj.log_requests is not None: + self.log_requests = obj.log_requests + if obj.log_requests_level is not None: + self.log_requests_level = obj.log_requests_level + if obj.dump_requests_folder is not None: + self.dump_requests_folder = obj.dump_requests_folder + if obj.dump_requests_threshold is not None: + self.dump_requests_threshold = obj.dump_requests_threshold + logging.info(f"Config logging: {obj=}") + def create_abort_task(self, obj: GenerateReqInput): # Abort the request if the client is disconnected. 
async def abort_request(): @@ -559,23 +712,36 @@ async def abort_request(): background_tasks.add_task(abort_request) return background_tasks - def create_handle_loop(self): - if not self.to_create_loop: + def auto_create_handle_loop(self): + if self.no_create_loop: return - self.to_create_loop = False + self.no_create_loop = True loop = asyncio.get_event_loop() - loop.create_task(self.handle_loop()) + self.asyncio_tasks.add( + loop.create_task(print_exception_wrapper(self.handle_loop)) + ) - signal_handler = SignalHandler(self) - loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler) - loop.create_task(self.sigterm_watchdog()) + # We cannot add signal handler when the tokenizer manager is not in + # the main thread due to the CPython limitation. + if threading.current_thread() is threading.main_thread(): + signal_handler = SignalHandler(self) + loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler) + else: + logger.warning( + "Signal handler is not added because the tokenizer manager is " + "not in the main thread. This disables graceful shutdown of the " + "tokenizer manager when SIGTERM is received." + ) + self.asyncio_tasks.add( + loop.create_task(print_exception_wrapper(self.sigterm_watchdog)) + ) async def sigterm_watchdog(self): while not self.gracefully_exit: - await asyncio.sleep(60) + await asyncio.sleep(5) - # drain requests + # Drain requests while True: remain_num_req = len(self.rid_to_state) logger.info( @@ -593,160 +759,222 @@ async def handle_loop(self): """The event loop that handles requests""" while True: - recv_obj: Union[ - BatchStrOut, - BatchEmbeddingOut, - BatchTokenIDOut, - UpdateWeightFromDiskReqOutput, - UpdateWeightsFromDistributedReqOutput, - GetWeightsByNameReqOutput, - InitWeightsUpdateGroupReqOutput, - ] = await self.recv_from_detokenizer.recv_pyobj() - - if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)): - for i, rid in enumerate(recv_obj.rids): - state = self.rid_to_state.get(rid, None) - if state is None: - continue - - recv_obj.meta_info[i]["id"] = rid - if isinstance(recv_obj, BatchStrOut): - out_dict = { - "text": recv_obj.output_strs[i], - "meta_info": recv_obj.meta_info[i], - } - elif isinstance(recv_obj, BatchTokenIDOut): - out_dict = { - "token_ids": recv_obj.output_ids[i], - "meta_info": recv_obj.meta_info[i], - } - else: - assert isinstance(recv_obj, BatchEmbeddingOut) - out_dict = { - "embedding": recv_obj.embeddings[i], - "meta_info": recv_obj.meta_info[i], - } - state.out_list.append(out_dict) - state.finished = recv_obj.finished_reason[i] is not None - state.event.set() - - if self.enable_metrics: - completion_tokens = recv_obj.meta_info[i]["completion_tokens"] - - if state.first_token_time is None: - state.first_token_time = time.time() - self.metrics_collector.observe_time_to_first_token( - state.first_token_time - state.created_time - ) - else: - if completion_tokens >= 2: - self.metrics_collector.observe_time_per_output_token( - (time.time() - state.first_token_time) - / (completion_tokens - 1) - ) - - if state.finished: - self.metrics_collector.inc_prompt_tokens( - recv_obj.meta_info[i]["prompt_tokens"] - ) - self.metrics_collector.inc_generation_tokens( - completion_tokens - ) - self.metrics_collector.observe_e2e_request_latency( - time.time() - state.created_time - ) - if completion_tokens >= 1: - self.metrics_collector.observe_time_per_output_token( - (time.time() - state.created_time) - / completion_tokens - ) - elif isinstance(recv_obj, OpenSessionReqOutput): - 
self.session_futures[recv_obj.session_id].set_result( - recv_obj.session_id + recv_obj = await self.recv_from_detokenizer.recv_pyobj() + self._result_dispatcher(recv_obj) + + def _handle_batch_output( + self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] + ): + for i, rid in enumerate(recv_obj.rids): + state = self.rid_to_state.get(rid, None) + if state is None: + continue + + meta_info = { + "id": rid, + "finish_reason": recv_obj.finished_reasons[i], + "prompt_tokens": recv_obj.prompt_tokens[i], + } + + if getattr(state.obj, "return_logprob", False): + self.convert_logprob_style( + meta_info, + state.obj.top_logprobs_num, + state.obj.return_text_in_logprobs, + recv_obj, + i, ) - elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput): - if self.server_args.dp_size == 1: - self.model_update_result.set_result(recv_obj) - else: # self.server_args.dp_size > 1 - self.model_update_tmp.append(recv_obj) - # set future if the all results are recevied - if len(self.model_update_tmp) == self.server_args.dp_size: - self.model_update_result.set_result(self.model_update_tmp) - elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput): - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" - self.init_weights_update_group_result.set_result(recv_obj) - elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput): - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for update weights from distributed" - self.parameter_update_result.set_result(recv_obj) - elif isinstance(recv_obj, GetWeightsByNameReqOutput): - if self.server_args.dp_size == 1: - self.get_weights_by_name_result.set_result(recv_obj) - else: - self.get_weights_by_name_tmp.append(recv_obj) - if len(self.get_weights_by_name_tmp) == self.server_args.dp_size: - self.get_weights_by_name_result.set_result( - self.get_weights_by_name_tmp - ) + + if self.server_args.speculative_algorithm: + meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i] + + if not isinstance(recv_obj, BatchEmbeddingOut): + meta_info.update( + { + "completion_tokens": recv_obj.completion_tokens[i], + "cached_tokens": recv_obj.cached_tokens[i], + } + ) + + if isinstance(recv_obj, BatchStrOut): + out_dict = { + "text": recv_obj.output_strs[i], + "meta_info": meta_info, + } + elif isinstance(recv_obj, BatchTokenIDOut): + out_dict = { + "token_ids": recv_obj.output_ids[i], + "meta_info": meta_info, + } else: - raise ValueError(f"Invalid object: {recv_obj=}") + assert isinstance(recv_obj, BatchEmbeddingOut) + out_dict = { + "embedding": recv_obj.embeddings[i], + "meta_info": meta_info, + } + + state.out_list.append(out_dict) + state.finished = recv_obj.finished_reasons[i] is not None + state.event.set() + + if self.enable_metrics and state.obj.log_metrics: + self.collect_metrics(state, recv_obj, i) + if self.dump_requests_folder and state.finished and state.obj.log_metrics: + self.dump_requests(state, out_dict) def convert_logprob_style( self, - ret: dict, - return_logprob: bool, + meta_info: dict, top_logprobs_num: int, return_text_in_logprobs: bool, + recv_obj: BatchStrOut, + recv_obj_index: int, ): - if return_logprob: - ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens( - ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs + meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens( + recv_obj.input_token_logprobs_val[recv_obj_index], + recv_obj.input_token_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + 
meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens( + recv_obj.output_token_logprobs_val[recv_obj_index], + recv_obj.output_token_logprobs_idx[recv_obj_index], + return_text_in_logprobs, + ) + + if top_logprobs_num > 0: + meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens( + recv_obj.input_top_logprobs_val[recv_obj_index], + recv_obj.input_top_logprobs_idx[recv_obj_index], + return_text_in_logprobs, ) - ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens( - ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs + meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens( + recv_obj.output_top_logprobs_val[recv_obj_index], + recv_obj.output_top_logprobs_idx[recv_obj_index], + return_text_in_logprobs, ) - if top_logprobs_num > 0: - ret["meta_info"]["input_top_logprobs"] = ( - self.detokenize_top_logprobs_tokens( - ret["meta_info"]["input_top_logprobs"], - return_text_in_logprobs, - ) - ) - ret["meta_info"]["output_top_logprobs"] = ( - self.detokenize_top_logprobs_tokens( - ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs - ) - ) - return ret - def detokenize_logprob_tokens( - self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool + self, + token_logprobs_val: List[float], + token_logprobs_idx: List[int], + decode_to_text: bool, ): - # TODO(lianmin): This should run on DetokenizerManager if not decode_to_text: - return [(logprob, token_id, None) for logprob, token_id in token_logprobs] - - assert self.tokenizer is not None - token_ids = [tid for _, tid in token_logprobs] - token_texts = self.tokenizer.batch_decode(token_ids) - return [ - (logprob, token_id, token_text) - for (logprob, token_id), token_text in zip(token_logprobs, token_texts) - ] + return [ + (logprob, token_id, None) + for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx) + ] + else: + assert self.tokenizer is not None + token_texts = self.tokenizer.batch_decode(token_logprobs_idx) + return list(zip(token_logprobs_val, token_logprobs_idx, token_texts)) - def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool): + def detokenize_top_logprobs_tokens( + self, + token_logprobs_val: List[float], + token_logprobs_idx: List[int], + decode_to_text: bool, + ): # TODO: The current implementation only batches the detokenization for top-k tokens per single position. # We should batch all top-k tokens in all positions. 
- for i, token_top_logprobs in enumerate(top_logprobs): - if token_top_logprobs: - top_logprobs[i] = self.detokenize_logprob_tokens( - token_top_logprobs, decode_to_text + ret = [] + for i in range(len(token_logprobs_val)): + if token_logprobs_val[i]: + ret.append( + self.detokenize_logprob_tokens( + token_logprobs_val[i], token_logprobs_idx[i], decode_to_text + ) + ) + else: + ret.append(None) + return ret + + def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int): + completion_tokens = ( + recv_obj.completion_tokens[i] + if getattr(recv_obj, "completion_tokens", None) + else 0 + ) + + if state.first_token_time is None: + state.first_token_time = time.time() + self.metrics_collector.observe_time_to_first_token( + state.first_token_time - state.created_time + ) + else: + if completion_tokens >= 2: + # Compute time_per_output_token for the streaming case + self.metrics_collector.observe_time_per_output_token( + (time.time() - state.first_token_time) / (completion_tokens - 1) ) - return top_logprobs + + if state.finished: + self.metrics_collector.observe_one_finished_request( + recv_obj.prompt_tokens[i], completion_tokens + ) + self.metrics_collector.observe_e2e_request_latency( + time.time() - state.created_time + ) + # Compute time_per_output_token for the non-streaming case + if ( + hasattr(state.obj, "stream") + and not state.obj.stream + and completion_tokens >= 1 + ): + self.metrics_collector.observe_time_per_output_token( + (time.time() - state.created_time) / completion_tokens + ) + + def dump_requests(self, state: ReqState, out_dict: dict): + self.dump_request_list.append( + (state.obj, out_dict, state.created_time, time.time()) + ) + + if len(self.dump_request_list) >= self.dump_requests_threshold: + filename = os.path.join( + self.dump_requests_folder, + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl", + ) + logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}") + + to_dump = self.dump_request_list + self.dump_request_list = [] + + def background_task(): + os.makedirs(self.dump_requests_folder, exist_ok=True) + with open(filename, "wb") as f: + pickle.dump(to_dump, f) + + # Schedule the task to run in the background without awaiting it + asyncio.create_task(asyncio.to_thread(background_task)) + + def _handle_open_session_req_output(self, recv_obj): + self.session_futures[recv_obj.session_id].set_result( + recv_obj.session_id if recv_obj.success else None + ) + + def _handle_update_weights_from_disk_req_output(self, recv_obj): + if self.server_args.dp_size == 1: + self.model_update_result.set_result(recv_obj) + else: # self.server_args.dp_size > 1 + self.model_update_tmp.append(recv_obj) + # set future if the all results are recevied + if len(self.model_update_tmp) == self.server_args.dp_size: + self.model_update_result.set_result(self.model_update_tmp) + + +async def print_exception_wrapper(func): + """ + Sometimes an asyncio function does not print exception. + We do another wrapper to handle the exception. + """ + try: + await func() + except Exception: + traceback = get_exception_traceback() + logger.error(f"TokenizerManager hit an exception: {traceback}") + kill_process_tree(os.getpid(), include_parent=True) + sys.exit(1) class SignalHandler: @@ -758,3 +986,28 @@ def signal_handler(self, signum=None, frame=None): f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..." 
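# Editorial sketch (not part of the patch): dump_requests above hands the blocking pickle
# write to a worker thread via asyncio.create_task(asyncio.to_thread(...)) so the tokenizer
# event loop never blocks on disk I/O. The same pattern in isolation (the file path and
# payload are made up):
import asyncio
import os
import pickle
import tempfile


def blocking_dump(path, payload):
    with open(path, "wb") as f:
        pickle.dump(payload, f)


async def main():
    fd, path = tempfile.mkstemp(suffix=".pkl")
    os.close(fd)
    # Fire-and-forget in the real code; awaited here only so the demo exits cleanly.
    task = asyncio.create_task(
        asyncio.to_thread(blocking_dump, path, {"requests": [1, 2, 3]})
    )
    await asyncio.sleep(0)  # the loop keeps serving other work while the dump runs
    await task
    print("dumped to", path)


asyncio.run(main())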
) self.tokenizer_manager.gracefully_exit = True + + +T = TypeVar("T") + + +class _Communicator(Generic[T]): + def __init__(self, sender, fan_out: int): + self._sender = sender + self._fan_out = fan_out + self._result_future: Optional[asyncio.Future] = None + self._result_values: Optional[List[T]] = None + + async def __call__(self, obj): + self._sender.send_pyobj(obj) + self._result_future = asyncio.Future() + self._result_values = [] + await self._result_future + result_values = self._result_values + self._result_future = self._result_values = None + return result_values + + def handle_recv(self, recv_obj: T): + self._result_values.append(recv_obj) + if len(self._result_values) == self._fan_out: + self._result_future.set_result(None) diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 3aa06b4b8ee..fd4dbae9900 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -24,12 +24,13 @@ InitWeightsUpdateGroupReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs -from sglang.srt.utils import broadcast_pyobj, set_random_seed +from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed logger = logging.getLogger(__name__) @@ -44,13 +45,18 @@ def __init__( tp_rank: int, dp_rank: Optional[int], nccl_port: int, + is_draft_worker: bool = False, ): # Parse args self.tp_rank = tp_rank # Init model and tokenizer self.model_config = ModelConfig( - server_args.model_path, + ( + server_args.model_path + if not is_draft_worker + else server_args.speculative_draft_model_path + ), trust_remote_code=server_args.trust_remote_code, revision=server_args.revision, context_length=server_args.context_length, @@ -67,6 +73,7 @@ def __init__( tp_size=server_args.tp_size, nccl_port=nccl_port, server_args=server_args, + is_draft_worker=is_draft_worker, ) if server_args.skip_tokenizer_init: self.tokenizer = self.processor = None @@ -76,6 +83,7 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) self.tokenizer = self.processor.tokenizer else: @@ -83,6 +91,7 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) self.device = self.model_runner.device @@ -94,6 +103,7 @@ def __init__( self.max_total_num_tokens // 2 if server_args.max_running_requests is None else server_args.max_running_requests + // (server_args.dp_size if server_args.enable_dp_attention else 1) ), self.model_runner.req_to_token_pool.size, ) @@ -135,26 +145,31 @@ def get_pad_input_ids_func(self): def get_tp_cpu_group(self): return self.model_runner.tp_group.cpu_group + def get_attention_tp_cpu_group(self): + return self.model_runner.attention_tp_group.cpu_group + def get_memory_pool(self): return ( self.model_runner.req_to_token_pool, self.model_runner.token_to_kv_pool, ) - def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch): - forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - self.model_runner.forward(forward_batch) - def 
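# Editorial sketch (not part of the patch): _Communicator above fans one request out (for
# example to dp_size schedulers) and resolves its future only after `fan_out` replies have
# been fed back through handle_recv. A self-contained model of that flow with a fake sender:
import asyncio
from typing import Generic, List, Optional, TypeVar

T = TypeVar("T")


class CommunicatorSketch(Generic[T]):
    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_future: Optional[asyncio.Future] = None
        self._result_values: Optional[List[T]] = None

    async def __call__(self, obj) -> List[T]:
        self._sender(obj)
        self._result_future = asyncio.Future()
        self._result_values = []
        await self._result_future
        values = self._result_values
        self._result_future = self._result_values = None
        return values

    def handle_recv(self, recv_obj: T) -> None:
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_future.set_result(None)


async def _demo():
    comm = CommunicatorSketch(sender=lambda obj: None, fan_out=2)
    task = asyncio.create_task(comm({"op": "get_weights_by_name"}))
    await asyncio.sleep(0)                 # let the request coroutine reach its await
    comm.handle_recv("reply from dp rank 0")
    comm.handle_recv("reply from dp rank 1")
    print(await task)                      # ['reply from dp rank 0', 'reply from dp rank 1']


asyncio.run(_demo())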
forward_batch_generation( self, model_worker_batch: ModelWorkerBatch, launch_done: Optional[threading.Event] = None, + skip_sample: bool = False, ): forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) logits_output = self.model_runner.forward(forward_batch) if launch_done: launch_done.set() - next_token_ids = self.model_runner.sample(logits_output, model_worker_batch) + + if skip_sample: + next_token_ids = None + else: + next_token_ids = self.model_runner.sample(logits_output, model_worker_batch) + return logits_output, next_token_ids def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): @@ -188,6 +203,12 @@ def update_weights_from_distributed( ) return success, message + def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): + success, message = self.model_runner.update_weights_from_tensor( + MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors) + ) + return success, message + def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): parameter = self.model_runner.get_weights_by_name( recv_req.name, recv_req.truncate_size diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index a9db1878391..961b0bbdc11 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -28,6 +28,7 @@ InitWeightsUpdateGroupReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.tp_worker import TpModelWorker @@ -81,6 +82,8 @@ def __init__( self.forward_thread.start() self.parent_process = psutil.Process().parent() self.scheduler_stream = torch.get_device_module(self.device).current_stream() + if self.device == "cpu": + self.scheduler_stream.synchronize = lambda: None # No-op for CPU def get_worker_info(self): return self.worker.get_worker_info() @@ -91,6 +94,9 @@ def get_pad_input_ids_func(self): def get_tp_cpu_group(self): return self.worker.get_tp_cpu_group() + def get_attention_tp_cpu_group(self): + return self.worker.get_attention_tp_cpu_group() + def get_memory_pool(self): return ( self.worker.model_runner.req_to_token_pool, @@ -143,19 +149,13 @@ def forward_thread_func_(self): # Copy results to the CPU if model_worker_batch.return_logprob: - logits_output.next_token_logprobs = logits_output.next_token_logprobs[ - torch.arange(len(next_token_ids), device=self.device), - next_token_ids, - ].to("cpu", non_blocking=True) + logits_output.next_token_logprobs = ( + logits_output.next_token_logprobs.to("cpu", non_blocking=True) + ) if logits_output.input_token_logprobs is not None: logits_output.input_token_logprobs = ( logits_output.input_token_logprobs.to("cpu", non_blocking=True) ) - logits_output.normalized_prompt_logprobs = ( - logits_output.normalized_prompt_logprobs.to( - "cpu", non_blocking=True - ) - ) next_token_ids = next_token_ids.to("cpu", non_blocking=True) copy_done.record() @@ -174,9 +174,6 @@ def resolve_batch_result(self, bid: int): logits_output.input_token_logprobs = ( logits_output.input_token_logprobs.tolist() ) - logits_output.normalized_prompt_logprobs = ( - logits_output.normalized_prompt_logprobs.tolist() - ) next_token_ids = next_token_ids.tolist() return logits_output, next_token_ids @@ -225,6 +222,10 @@ def update_weights_from_distributed( success, message = 
self.worker.update_weights_from_distributed(recv_req) return success, message + def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): + success, message = self.worker.update_weights_from_tensor(recv_req) + return success, message + def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): return self.worker.get_weights_by_name(recv_req) diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py new file mode 100644 index 00000000000..10a1209631e --- /dev/null +++ b/python/sglang/srt/managers/utils.py @@ -0,0 +1,44 @@ +import logging +from http import HTTPStatus +from typing import Optional + +from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req + +logger = logging.getLogger(__name__) + + +def validate_input_length( + req: Req, max_req_input_len: int, allow_auto_truncate: bool +) -> Optional[str]: + """Validate and potentially truncate input length. + + Args: + req: The request containing input_ids to validate + max_req_input_len: Maximum allowed input length + allow_auto_truncate: Whether to truncate long inputs + + Returns: + Error message if validation fails, None if successful + """ + if len(req.origin_input_ids) >= max_req_input_len: + if allow_auto_truncate: + logger.warning( + "Request length is longer than the KV cache pool size or " + "the max context length. Truncated. " + f"{len(req.origin_input_ids)=}, {max_req_input_len=}." + ) + req.origin_input_ids = req.origin_input_ids[:max_req_input_len] + return None + else: + error_msg = ( + f"Input length ({len(req.origin_input_ids)} tokens) exceeds " + f"the maximum allowed length ({max_req_input_len} tokens). " + f"Use a shorter input or enable --allow-auto-truncate." + ) + logger.error(error_msg) + req.finished_reason = FINISH_ABORT( + error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError" + ) + return error_msg + + return None diff --git a/python/sglang/srt/mem_cache/base_prefix_cache.py b/python/sglang/srt/mem_cache/base_prefix_cache.py index 2808ca872a5..9386595a8bd 100644 --- a/python/sglang/srt/mem_cache/base_prefix_cache.py +++ b/python/sglang/srt/mem_cache/base_prefix_cache.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable +from typing import Callable, List, Tuple class BasePrefixCache(ABC): @@ -10,7 +10,7 @@ def reset(self): pass @abstractmethod - def match_prefix(self, **kwargs): + def match_prefix(self, **kwargs) -> Tuple[List[int], int]: pass @abstractmethod @@ -41,6 +41,10 @@ def dec_lock_ref(self, node): def evictable_size(self): pass + @abstractmethod + def protected_size(self): + raise NotImplementedError() + def total_size(self): raise NotImplementedError() diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 3c430aba368..b50199ca28a 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -2,7 +2,7 @@ """Cache for chunked prefill, used when RadixCache is disabled.""" -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool @@ -30,7 +30,7 @@ def __init__( def reset(self): self.entries = {} - def match_prefix(self, rid: int, key: List[int]): + def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]: if rid not in self.entries: return [], None @@ -85,3 +85,6 @@ def 
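# Editorial sketch (not part of the patch): validate_input_length above either truncates an
# over-long request (when auto-truncation is allowed) or marks it aborted and returns an
# error string. A toy request object is enough to exercise both branches; ToyReq stands in
# for the real Req class:
from typing import List, Optional


class ToyReq:
    def __init__(self, input_ids: List[int]):
        self.origin_input_ids = input_ids
        self.finished_reason = None


def validate_sketch(req: ToyReq, max_len: int, allow_auto_truncate: bool) -> Optional[str]:
    if len(req.origin_input_ids) >= max_len:
        if allow_auto_truncate:
            req.origin_input_ids = req.origin_input_ids[:max_len]
            return None
        error = (
            f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
            f"the maximum allowed length ({max_len} tokens)."
        )
        req.finished_reason = ("FINISH_ABORT", error)  # the real code builds a FINISH_ABORT object
        return error
    return None


req = ToyReq(list(range(10)))
print(validate_sketch(req, max_len=4, allow_auto_truncate=True), len(req.origin_input_ids))  # None 4
print(validate_sketch(ToyReq(list(range(10))), max_len=4, allow_auto_truncate=False))        # error string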
dec_lock_ref(self, node): def evictable_size(self): return 0 + + def protected_size(self): + return 0 diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 646e71749d8..7b9b35611d8 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -13,6 +13,8 @@ limitations under the License. """ +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter + """ Memory pool. @@ -22,38 +24,48 @@ """ import logging -from typing import List, Tuple, Union +import threading +from enum import IntEnum +from functools import wraps +from typing import List, Optional, Tuple, Union +import numpy as np +import psutil import torch from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.utils import get_compiler_backend +from sglang.srt.utils import debug_timing, get_compiler_backend logger = logging.getLogger(__name__) +GB = 1024 * 1024 * 1024 + class ReqToTokenPool: """A memory pool that maps a request to its token locations.""" - def __init__(self, size: int, max_context_len: int, device: str, use_records: bool): + def __init__( + self, + size: int, + max_context_len: int, + device: str, + enable_memory_saver: bool, + ): + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=enable_memory_saver + ) + self.size = size self.max_context_len = max_context_len self.device = device - self.req_to_token = torch.zeros( - (size, max_context_len), dtype=torch.int32, device=device - ) + with memory_saver_adapter.region(): + self.req_to_token = torch.zeros( + (size, max_context_len), dtype=torch.int32, device=device + ) self.free_slots = list(range(size)) - self.write_records = [] - self.use_records = use_records - - if self.use_records: - self.write = self.write_with_records - else: - self.write = self.write_without_records def write(self, indices, values): - # Keep the signature for type checking. It will be assigned during runtime. - raise NotImplementedError() + self.req_to_token[indices] = values def available_size(self): return len(self.free_slots) @@ -75,23 +87,6 @@ def free(self, free_index: Union[int, List[int]]): def clear(self): self.free_slots = list(range(self.size)) - self.write_records = [] - - def write_without_records(self, indices, values): - self.req_to_token[indices] = values - - def write_with_records(self, indices, values): - self.req_to_token[indices] = values - self.write_records.append((indices, values)) - - def get_write_records(self): - ret = self.write_records - self.write_records = [] - return ret - - def apply_write_records(self, write_records: List[Tuple]): - for indices, values in write_records: - self.req_to_token[indices] = values class BaseTokenToKVPool: @@ -105,8 +100,8 @@ def __init__( ): self.size = size self.dtype = dtype - if dtype == torch.float8_e5m2: - # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2 + if dtype in (torch.float8_e5m2, torch.float8_e4m3fn): + # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2 self.store_dtype = torch.uint8 else: self.store_dtype = dtype @@ -182,27 +177,79 @@ def __init__( head_dim: int, layer_num: int, device: str, + enable_memory_saver: bool, ): super().__init__(size, dtype, device) - # [size, head_num, head_dim] for each layer - # The padded slot 0 is used for writing dummy outputs from padded tokens. 
- self.k_buffer = [ - torch.empty( - (size + 1, head_num, head_dim), - dtype=self.store_dtype, - device=device, - ) - for _ in range(layer_num) - ] - self.v_buffer = [ - torch.empty( - (size + 1, head_num, head_dim), - dtype=self.store_dtype, - device=device, - ) - for _ in range(layer_num) - ] + self.memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=enable_memory_saver + ) + + self.head_num = head_num + self.head_dim = head_dim + self.layer_num = layer_num + self._create_buffers() + + k_size, v_size = self.get_kv_size_bytes() + logger.info( + f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB." + ) + + def _create_buffers(self): + with self.memory_saver_adapter.region(): + # [size, head_num, head_dim] for each layer + # The padded slot 0 is used for writing dummy outputs from padded tokens. + self.k_buffer = [ + torch.empty( + (self.size + 1, self.head_num, self.head_dim), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + self.v_buffer = [ + torch.empty( + (self.size + 1, self.head_num, self.head_dim), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + + def _clear_buffers(self): + del self.k_buffer + del self.v_buffer + + def get_kv_size_bytes(self): + assert hasattr(self, "k_buffer") + assert hasattr(self, "v_buffer") + k_size_bytes = 0 + for k_cache in self.k_buffer: + k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize + v_size_bytes = 0 + for v_cache in self.v_buffer: + v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize + return k_size_bytes, v_size_bytes + + # Todo: different memory layout + def get_flat_data(self, indices): + # prepare a large chunk of contiguous data for efficient transfer + flatten = torch.stack( + [ + torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]), + torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]), + ] + ) + return flatten + + @debug_timing + def transfer(self, indices, flat_data): + # transfer prepared data from host to device + flat_data = flat_data.to(device=self.device, non_blocking=False) + k_data, v_data = flat_data[0], flat_data[1] + for i in range(self.layer_num): + self.k_buffer[i][indices] = k_data[i] + self.v_buffer[i][indices] = v_data[i] def get_key_buffer(self, layer_id: int): if self.store_dtype != self.dtype: @@ -223,9 +270,15 @@ def set_kv_buffer( loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, ): layer_id = layer.layer_id if cache_k.dtype != self.dtype: + if k_scale is not None: + cache_k.div_(k_scale) + if v_scale is not None: + cache_v.div_(v_scale) cache_k = cache_k.to(self.dtype) cache_v = cache_v.to(self.dtype) if self.store_dtype != self.dtype: @@ -245,7 +298,6 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype): class MLATokenToKVPool(BaseTokenToKVPool): - def __init__( self, size: int, @@ -254,19 +306,26 @@ def __init__( qk_rope_head_dim: int, layer_num: int, device: str, + enable_memory_saver: bool, ): super().__init__(size, dtype, device) self.kv_lora_rank = kv_lora_rank - # The padded slot 0 is used for writing dummy outputs from padded tokens. 
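# Editorial sketch (not part of the patch): get_kv_size_bytes above sums
# prod(shape) * itemsize over every per-layer K and V buffer. The same arithmetic written
# out for one hypothetical configuration (fp16 KV cache; all numbers are made up):
from math import prod

size = 100_000      # token slots (the real pool adds one padded slot)
head_num = 8
head_dim = 128
layer_num = 32
itemsize = 2        # bytes per element for float16

bytes_per_buffer = prod((size + 1, head_num, head_dim)) * itemsize
k_bytes = v_bytes = bytes_per_buffer * layer_num

GB = 1024 ** 3
print(f"K size: {k_bytes / GB:.2f} GB, V size: {v_bytes / GB:.2f} GB")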
- self.kv_buffer = [ - torch.empty( - (size + 1, 1, kv_lora_rank + qk_rope_head_dim), - dtype=self.store_dtype, - device=device, - ) - for _ in range(layer_num) - ] + + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=enable_memory_saver + ) + + with memory_saver_adapter.region(): + # The padded slot 0 is used for writing dummy outputs from padded tokens. + self.kv_buffer = [ + torch.empty( + (size + 1, 1, kv_lora_rank + qk_rope_head_dim), + dtype=self.store_dtype, + device=device, + ) + for _ in range(layer_num) + ] def get_key_buffer(self, layer_id: int): if self.store_dtype != self.dtype: @@ -298,7 +357,6 @@ def set_kv_buffer( class DoubleSparseTokenToKVPool(BaseTokenToKVPool): - def __init__( self, size: int, @@ -308,26 +366,32 @@ def __init__( layer_num: int, device: str, heavy_channel_num: int, + enable_memory_saver: bool, ): super().__init__(size, dtype, device) - # [size, head_num, head_dim] for each layer - self.k_buffer = [ - torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) - for _ in range(layer_num) - ] - self.v_buffer = [ - torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) - for _ in range(layer_num) - ] - - # [size, head_num, heavy_channel_num] for each layer - self.label_buffer = [ - torch.empty( - (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device - ) - for _ in range(layer_num) - ] + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=enable_memory_saver + ) + + with memory_saver_adapter.region(): + # [size, head_num, head_dim] for each layer + self.k_buffer = [ + torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) + for _ in range(layer_num) + ] + self.v_buffer = [ + torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) + for _ in range(layer_num) + ] + + # [size, head_num, heavy_channel_num] for each layer + self.label_buffer = [ + torch.empty( + (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device + ) + for _ in range(layer_num) + ] def get_key_buffer(self, layer_id: int): return self.k_buffer[layer_id] @@ -354,3 +418,184 @@ def set_kv_buffer( self.k_buffer[layer_id][loc] = cache_k self.v_buffer[layer_id][loc] = cache_v self.label_buffer[layer_id][loc] = cache_label + + +class MemoryStateInt(IntEnum): + IDLE = 0 + RESERVED = 1 + PROTECTED = 2 + SYNCED = 3 + BACKUP = 4 + + +def synchronized(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + with self.lock: + return func(self, *args, **kwargs) + + return wrapper + + +class MLATokenToKVPoolHost: + + def __init__( + self, + device_pool: MHATokenToKVPool, + host_to_device_ratio: float = 2.0, + pin_memory: bool = False, # no need to use pin memory with the double buffering + device: str = "cpu", + ): + assert ( + host_to_device_ratio >= 1 + ), "The host memory should be larger than the device memory with the current protocol" + # todo, other ways of configuring the size + + self.device_pool = device_pool + self.host_to_device_ratio = host_to_device_ratio + self.pin_memory = pin_memory + self.device = device + + self.size = int(device_pool.size * host_to_device_ratio) + self.dtype = device_pool.store_dtype + self.head_num = device_pool.head_num + self.head_dim = device_pool.head_dim + self.layer_num = device_pool.layer_num + self.size_per_token = ( + self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2 + ) + + # Verify there is enough available host memory. 
+ host_mem = psutil.virtual_memory() + requested_bytes = self.size * self.size_per_token + # preserve at least 10GB for other usage + ten_gb = 10 * (1024**3) + if requested_bytes > host_mem.available - ten_gb: + raise ValueError( + f"Not enough host memory available. Requesting " + f"{requested_bytes / 1e9:.2f} GB but only have " + f"{host_mem.available / 1e9:.2f} GB free. Please reduce the " + f"size of the hierarchical cache." + ) + else: + logger.info( + f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache." + ) + + self.kv_buffer = torch.empty( + (2, self.layer_num, self.size, self.head_num, self.head_dim), + dtype=self.dtype, + device=self.device, + pin_memory=self.pin_memory, + ) + + # Initialize memory states and tracking structures. + self.mem_state = torch.zeros( + (self.size,), dtype=torch.uint8, device=self.device + ) + self.free_slots = torch.arange(self.size, dtype=torch.int32) + self.can_use_mem_size = self.size + + # A lock for synchronized operations on memory allocation and state transitions. + self.lock = threading.RLock() + + def get_flat_data(self, indices): + return self.kv_buffer[:, :, indices] + + @debug_timing + def transfer(self, indices, flat_data): + # backup prepared data from device to host + self.kv_buffer[:, :, indices] = flat_data.to( + device=self.device, non_blocking=False + ) + + @synchronized + def clear(self): + self.mem_state.fill_(0) + self.can_use_mem_size = self.size + self.free_slots = torch.arange(self.size, dtype=torch.int32) + + @synchronized + def get_state(self, indices: torch.Tensor) -> MemoryStateInt: + assert len(indices) > 0, "The indices should not be empty" + states = self.mem_state[indices] + assert ( + states == states[0] + ).all(), "The memory slots should have the same state {}".format(states) + return MemoryStateInt(states[0].item()) + + @synchronized + def alloc(self, need_size: int) -> torch.Tensor: + if need_size > self.can_use_mem_size: + return None + + # todo: de-fragementation + select_index = self.free_slots[:need_size] + self.free_slots = self.free_slots[need_size:] + + self.mem_state[select_index] = MemoryStateInt.RESERVED + self.can_use_mem_size -= need_size + + return select_index + + @synchronized + def is_reserved(self, indices: torch.Tensor) -> bool: + return self.get_state(indices) == MemoryStateInt.RESERVED + + @synchronized + def is_protected(self, indices: torch.Tensor) -> bool: + return self.get_state(indices) == MemoryStateInt.PROTECTED + + @synchronized + def is_synced(self, indices: torch.Tensor) -> bool: + return self.get_state(indices) == MemoryStateInt.SYNCED + + @synchronized + def is_backup(self, indices: torch.Tensor) -> bool: + return self.get_state(indices) == MemoryStateInt.BACKUP + + @synchronized + def update_backup(self, indices: torch.Tensor): + assert self.is_synced(indices), ( + f"The host memory slots should be in SYNCED state before turning into BACKUP. " + f"Current state: {self.get_state(indices)}" + ) + self.mem_state[indices] = MemoryStateInt.BACKUP + + @synchronized + def update_synced(self, indices: torch.Tensor): + self.mem_state[indices] = MemoryStateInt.SYNCED + + @synchronized + def protect_write(self, indices: torch.Tensor): + assert self.is_reserved(indices), ( + f"The host memory slots should be RESERVED before write operations. 
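# Editorial sketch (not part of the patch): the asserts in protect_write/protect_load/
# complete_io above enforce one write/read lifecycle per host slot. A toy single-slot
# version of that state machine (no torch, no locking), just to make the allowed order
# explicit; the transition table is an assumption read off the asserts:
from enum import IntEnum


class State(IntEnum):
    IDLE = 0
    RESERVED = 1
    PROTECTED = 2
    SYNCED = 3
    BACKUP = 4


ALLOWED = {
    ("alloc", State.IDLE): State.RESERVED,
    ("protect_write", State.RESERVED): State.PROTECTED,
    ("complete_io", State.PROTECTED): State.SYNCED,
    ("update_backup", State.SYNCED): State.BACKUP,
    ("protect_load", State.BACKUP): State.PROTECTED,
    ("free", State.SYNCED): State.IDLE,
    ("free", State.BACKUP): State.IDLE,
}


def step(state: State, op: str) -> State:
    nxt = ALLOWED.get((op, state))
    assert nxt is not None, f"{op} is not allowed from state {state.name}"
    return nxt


s = State.IDLE
for op in ["alloc", "protect_write", "complete_io", "update_backup",
           "protect_load", "complete_io", "free"]:
    s = step(s, op)
    print(f"{op:14s} -> {s.name}")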
" + f"Current state: {self.get_state(indices)}" + ) + self.mem_state[indices] = MemoryStateInt.PROTECTED + + @synchronized + def protect_load(self, indices: torch.Tensor): + assert self.is_backup(indices), ( + f"The host memory slots should be in BACKUP state before load operations. " + f"Current state: {self.get_state(indices)}" + ) + self.mem_state[indices] = MemoryStateInt.PROTECTED + + @synchronized + def complete_io(self, indices: torch.Tensor): + assert self.is_protected(indices), ( + f"The host memory slots should be PROTECTED during I/O operations. " + f"Current state: {self.get_state(indices)}" + ) + self.mem_state[indices] = MemoryStateInt.SYNCED + + def available_size(self): + return len(self.free_slots) + + @synchronized + def free(self, indices: torch.Tensor) -> int: + self.mem_state[indices] = MemoryStateInt.IDLE + self.free_slots = torch.concat([self.free_slots, indices]) + self.can_use_mem_size += len(indices) + return len(indices) diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index 8cd8354b6b2..3bf87b54299 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -22,7 +22,7 @@ import heapq import time from collections import defaultdict -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple import torch @@ -34,7 +34,10 @@ class TreeNode: - def __init__(self): + + counter = 0 + + def __init__(self, id: Optional[int] = None): self.children = defaultdict(TreeNode) self.parent = None self.key = None @@ -42,6 +45,23 @@ def __init__(self): self.lock_ref = 0 self.last_access_time = time.time() + self.hit_count = 0 + # indicating the node is loading KV cache from host + self.loading = False + # store the host indices of KV cache + self.host_value = None + + self.id = TreeNode.counter if id is None else id + TreeNode.counter += 1 + + @property + def evicted(self): + return self.value is None + + @property + def backuped(self): + return self.host_value is not None + def __lt__(self, other: "TreeNode"): return self.last_access_time < other.last_access_time @@ -75,8 +95,19 @@ def reset(self): self.root_node.value = [] self.root_node.lock_ref = 1 self.evictable_size_ = 0 - - def match_prefix(self, key: List, **kwargs): + self.protected_size_ = 0 + + def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]: + """Find the matching prefix from the radix tree. + Args: + key: A list of token IDs to find a matching prefix. + Returns: + A tuple of a tensor of matching prefix token IDs and + the last node that contains the prefix values. Note that + this API can modify the internal state of the Radix tree. + The last node create a new child if the prefix is shorter + than the last node's value. 
+ """ if self.disable: return [], self.root_node @@ -193,6 +224,7 @@ def inc_lock_ref(self, node: TreeNode): while node != self.root_node: if node.lock_ref == 0: self.evictable_size_ -= len(node.value) + self.protected_size_ += len(node.value) delta -= len(node.value) node.lock_ref += 1 node = node.parent @@ -206,6 +238,7 @@ def dec_lock_ref(self, node: TreeNode): while node != self.root_node: if node.lock_ref == 1: self.evictable_size_ += len(node.value) + self.protected_size_ -= len(node.value) delta += len(node.value) node.lock_ref -= 1 node = node.parent @@ -214,6 +247,10 @@ def dec_lock_ref(self, node: TreeNode): def evictable_size(self): return self.evictable_size_ + def protected_size(self): + # protected size refers to the size of the cache that is locked + return self.protected_size_ + ##### Internal Helper Functions ##### def _match_prefix_helper( @@ -293,6 +330,8 @@ def _delete_leaf(self, node): self.evictable_size_ -= len(node.key) def _total_size_helper(self, node: TreeNode): + if node.evicted: + return 0 x = len(node.value) for child in node.children.values(): x += self._total_size_helper(child) diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index d5ae98834b3..26eb2fc27d2 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -25,6 +25,7 @@ class SchedulerStats: gen_throughput: float = 0.0 num_queue_reqs: int = 0 cache_hit_rate: float = 0.0 + spec_accept_length: float = 0.0 class SchedulerMetricsCollector: @@ -37,42 +38,49 @@ def __init__(self, labels: Dict[str, str]) -> None: self.num_running_reqs = Gauge( name="sglang:num_running_reqs", - documentation="The number of running requests", + documentation="The number of running requests.", labelnames=labels.keys(), multiprocess_mode="sum", ) self.num_used_tokens = Gauge( name="sglang:num_used_tokens", - documentation="The number of used tokens", + documentation="The number of used tokens.", labelnames=labels.keys(), multiprocess_mode="sum", ) self.token_usage = Gauge( name="sglang:token_usage", - documentation="The token usage", + documentation="The token usage.", labelnames=labels.keys(), multiprocess_mode="mostrecent", ) self.gen_throughput = Gauge( name="sglang:gen_throughput", - documentation="The generate throughput (token/s)", + documentation="The generation throughput (token/s).", labelnames=labels.keys(), multiprocess_mode="sum", ) self.num_queue_reqs = Gauge( name="sglang:num_queue_reqs", - documentation="The number of requests in the waiting queue", + documentation="The number of requests in the waiting queue.", labelnames=labels.keys(), multiprocess_mode="sum", ) self.cache_hit_rate = Gauge( name="sglang:cache_hit_rate", - documentation="The cache hit rate", + documentation="The prefix cache hit rate.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + + self.spec_accept_length = Gauge( + name="sglang:spec_accept_length", + documentation="The average acceptance length of speculative decoding.", labelnames=labels.keys(), multiprocess_mode="mostrecent", ) @@ -88,6 +96,7 @@ def log_stats(self, stats: SchedulerStats) -> None: self._log_gauge(self.gen_throughput, stats.gen_throughput) self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs) self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate) + self._log_gauge(self.spec_accept_length, stats.spec_accept_length) class TokenizerMetricsCollector: @@ -109,31 +118,31 @@ def __init__(self, labels: Dict[str, str]) -> None: labelnames=labels.keys(), ) + 
self.num_requests_total = Counter( + name="sglang:num_requests_total", + documentation="Number of requests processed.", + labelnames=labels.keys(), + ) + self.histogram_time_to_first_token = Histogram( name="sglang:time_to_first_token_seconds", documentation="Histogram of time to first token in seconds.", labelnames=labels.keys(), buckets=[ - 0.001, - 0.005, - 0.01, - 0.02, - 0.04, - 0.06, - 0.08, 0.1, 0.25, 0.5, 0.75, - 1.0, - 2.5, - 5.0, - 7.5, - 10.0, - 15.0, - 20.0, - 25.0, - 30.0, + 1, + 2, + 5, + 10, + 20, + 40, + 60, + 80, + 120, + 160, ], ) @@ -168,21 +177,19 @@ def __init__(self, labels: Dict[str, str]) -> None: documentation="Histogram of End-to-end request latency in seconds", labelnames=labels.keys(), buckets=[ - 0.3, + 0.1, + 0.25, 0.5, - 0.8, - 1.0, - 1.5, - 2.0, - 2.5, - 5.0, - 10.0, - 15.0, - 20.0, - 30.0, - 40.0, - 50.0, - 60.0, + 1, + 2, + 5, + 10, + 20, + 40, + 60, + 80, + 120, + 160, ], ) @@ -193,11 +200,10 @@ def _log_counter(self, counter, data: Union[int, float]) -> None: # Convenience function for logging to counter. counter.labels(**self.labels).inc(data) - def inc_prompt_tokens(self, value: int): - self._log_counter(self.prompt_tokens_total, value) - - def inc_generation_tokens(self, value: int): - self._log_counter(self.generation_tokens_total, value) + def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int): + self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens) + self.generation_tokens_total.labels(**self.labels).inc(generation_tokens) + self.num_requests_total.labels(**self.labels).inc(1) def observe_time_to_first_token(self, value: Union[float, int]): self._log_histogram(self.histogram_time_to_first_token, value) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 27043cc9a7d..93b4d0ea57a 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -20,23 +20,25 @@ from typing import TYPE_CHECKING, Callable import torch -from vllm.distributed.parallel_state import graph_capture +import tqdm from vllm.model_executor.custom_op import CustomOp -from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native -from sglang.srt.layers.logits_processor import ( - LogitsMetadata, - LogitsProcessor, - LogitsProcessorOutput, +from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native +from sglang.srt.layers.torchao_utils import save_gemlite_cache +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, ) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner -def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int): +def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int): for sub in model._modules.values(): if isinstance(sub, CustomOp): if reverse: @@ -45,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int): else: # NOTE: Temporarily workaround MoE if "FusedMoE" in sub.__class__.__name__: - if batch_size == 1: + if num_tokens == 1: # The performance of 
torch.compile on this layer is not always good when bs > 1, # so we decide to only use torch.compile when bs =1 sub._forward_method = fused_moe_forward_native @@ -53,23 +55,22 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int): sub._forward_method = sub.forward_native setattr(sub, "is_torch_compile", True) if isinstance(sub, torch.nn.Module): - _to_torch(sub, reverse, batch_size) + _to_torch(sub, reverse, num_tokens) @contextmanager def patch_model( model: torch.nn.Module, enable_compile: bool, - batch_size: int, - tp_group: "GroupCoordinator", + num_tokens: int, + tp_group: GroupCoordinator, ): """Patch the model to make it compatible with with torch.compile""" backup_ca_comm = None try: if enable_compile: - _to_torch(model, reverse=False, batch_size=batch_size) - monkey_patch_vllm_all_gather() + _to_torch(model, reverse=False, num_tokens=num_tokens) backup_ca_comm = tp_group.ca_comm # Use custom-allreduce here. # We found the custom allreduce is much faster than the built-in allreduce in torch, @@ -84,8 +85,7 @@ def patch_model( yield model.forward finally: if enable_compile: - _to_torch(model, reverse=True, batch_size=batch_size) - monkey_patch_vllm_all_gather(reverse=True) + _to_torch(model, reverse=True, num_tokens=num_tokens) tp_group.ca_comm = backup_ca_comm @@ -103,11 +103,6 @@ def set_torch_compile_config(): torch._dynamo.config.cache_size_limit = 1024 -@maybe_torch_compile(dynamic=True) -def clamp_position(seq_lens): - return torch.clamp((seq_lens - 1), min=0).to(torch.int64) - - class CudaGraphRunner: """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" @@ -124,12 +119,15 @@ def __init__(self, model_runner: "ModelRunner"): self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention self.tp_size = self.model_runner.tp_size + self.dp_size = self.model_runner.server_args.dp_size # Batch sizes to capture - if model_runner.server_args.disable_cuda_graph_padding: - self.capture_bs = list(range(1, 32)) + [64, 128] - else: - self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] + self.capture_bs = self.model_runner.server_args.cuda_graph_bs + if self.capture_bs is None: + if model_runner.server_args.disable_cuda_graph_padding: + self.capture_bs = list(range(1, 33)) + [64, 128] + else: + self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] if max(self.capture_bs) > model_runner.req_to_token_pool.size: # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests @@ -150,6 +148,7 @@ def __init__(self, model_runner: "ModelRunner"): if bs <= model_runner.req_to_token_pool.size and bs <= model_runner.server_args.cuda_graph_max_bs ] + self.compile_bs = ( [ bs @@ -160,14 +159,26 @@ def __init__(self, model_runner: "ModelRunner"): else [] ) + self.capture_forward_mode = ForwardMode.DECODE + self.num_tokens_per_bs = 1 + if model_runner.spec_algorithm.is_eagle(): + if self.model_runner.is_draft_worker: + self.num_tokens_per_bs = ( + self.model_runner.server_args.speculative_eagle_topk + ) + else: + self.capture_forward_mode = ForwardMode.TARGET_VERIFY + self.num_tokens_per_bs = ( + self.model_runner.server_args.speculative_num_draft_tokens + ) + # Attention backend self.max_bs = max(self.capture_bs) - self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs) - + self.max_num_token = self.max_bs * self.num_tokens_per_bs + self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token) 
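# Editorial sketch (not part of the patch): with speculative decoding enabled, each captured
# CUDA graph now processes num_tokens_per_bs tokens per request instead of exactly one, so
# buffers are sized by max_bs * num_tokens_per_bs. The multiplier selection below mirrors
# the hunk above; the concrete values are made up:
def num_tokens_per_bs_sketch(spec_is_eagle: bool, is_draft_worker: bool,
                             eagle_topk: int, num_draft_tokens: int) -> int:
    if not spec_is_eagle:
        return 1                  # plain decode: one token per request
    if is_draft_worker:
        return eagle_topk         # draft model extends the top-k candidates
    return num_draft_tokens       # target model verifies the whole draft tree


capture_bs = [1, 2, 4, 8, 16]
multiplier = num_tokens_per_bs_sketch(True, False, eagle_topk=4, num_draft_tokens=8)
max_num_token = max(capture_bs) * multiplier
print(multiplier, max_num_token)  # 8 128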
self.seq_len_fill_value = ( self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() ) - # FIXME(lsyin): leave it here for now, I don't know whether it is necessary self.encoder_len_fill_value = 0 @@ -176,13 +187,21 @@ def __init__(self, model_runner: "ModelRunner"): # Common inputs with torch.device("cuda"): - self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32) + self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) self.seq_lens = torch.full( (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 ) - self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32) - self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32) + self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) + self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) + + # Speculative_inference + if model_runner.spec_algorithm.is_eagle(): + self.hidden_states = torch.zeros( + (self.max_num_token, self.model_runner.model_config.hidden_size), + dtype=self.model_runner.dtype, + ) if self.is_encoder_decoder: # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch @@ -195,7 +214,7 @@ def __init__(self, model_runner: "ModelRunner"): if self.enable_dp_attention: self.gathered_buffer = torch.zeros( ( - self.max_bs * self.tp_size, + self.max_bs * self.dp_size, self.model_runner.model_config.hidden_size, ), dtype=self.model_runner.dtype, @@ -255,12 +274,17 @@ def can_run(self, forward_batch: ForwardBatch): def capture(self): with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream - for bs in self.capture_bs: + capture_range = ( + tqdm.tqdm(self.capture_bs) + if get_tensor_model_parallel_rank() == 0 + else self.capture_bs + ) + for bs in capture_range: with patch_model( self.model_runner.model, bs in self.compile_bs, - bs, - self.model_runner.tp_group, + num_tokens=bs * self.num_tokens_per_bs, + tp_group=self.model_runner.tp_group, ) as forward: ( graph, @@ -269,21 +293,24 @@ def capture(self): self.graphs[bs] = graph self.output_buffers[bs] = output_buffers + # Save gemlite cache after each capture + save_gemlite_cache() + def capture_one_batch_size(self, bs: int, forward: Callable): graph = torch.cuda.CUDAGraph() stream = self.stream + num_tokens = bs * self.num_tokens_per_bs # Common inputs - input_ids = self.input_ids[:bs] + input_ids = self.input_ids[:num_tokens] req_pool_indices = self.req_pool_indices[:bs] seq_lens = self.seq_lens[:bs] - out_cache_loc = self.out_cache_loc[:bs] + out_cache_loc = self.out_cache_loc[:num_tokens] + positions = self.positions[:num_tokens] if self.is_encoder_decoder: encoder_lens = self.encoder_lens[:bs] else: encoder_lens = None - - seq_lens_sum = seq_lens.sum().item() mrope_positions = self.mrope_positions[:, :bs] if self.enable_dp_attention: @@ -293,37 +320,48 @@ def capture_one_batch_size(self, bs: int, forward: Callable): global_num_tokens = None gathered_buffer = None + spec_info = self.get_spec_info(num_tokens, positions) + + forward_batch = ForwardBatch( + forward_mode=self.capture_forward_mode, + batch_size=bs, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + attn_backend=self.model_runner.attn_backend, + 
out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens.sum(), + encoder_lens=encoder_lens, + return_logprob=False, + top_logprobs_nums=[0] * bs, + positions=positions, + global_num_tokens=global_num_tokens, + gathered_buffer=gathered_buffer, + mrope_positions=mrope_positions, + spec_algorithm=self.model_runner.spec_algorithm, + spec_info=spec_info, + capture_hidden_mode=( + spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL + ), + ) + # Attention backend self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( bs, + num_tokens, req_pool_indices, seq_lens, encoder_lens, + forward_batch.forward_mode, + forward_batch.spec_info, ) # Run and capture def run_once(): - forward_batch = ForwardBatch( - forward_mode=ForwardMode.DECODE, - batch_size=bs, - input_ids=input_ids, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - req_to_token_pool=self.model_runner.req_to_token_pool, - token_to_kv_pool=self.model_runner.token_to_kv_pool, - attn_backend=self.model_runner.attn_backend, - out_cache_loc=out_cache_loc, - seq_lens_sum=seq_lens_sum, - encoder_lens=encoder_lens, - return_logprob=False, - top_logprobs_nums=[0] * bs, - positions=clamp_position(seq_lens), - mrope_positions=mrope_positions, - global_num_tokens=global_num_tokens, - gathered_buffer=gathered_buffer, - ) logits_output = forward(input_ids, forward_batch.positions, forward_batch) - return logits_output.next_token_logits + return logits_output.next_token_logits, logits_output.hidden_states for _ in range(2): torch.cuda.synchronize() @@ -349,6 +387,7 @@ def run_once(): def replay(self, forward_batch: ForwardBatch): assert forward_batch.out_cache_loc is not None raw_bs = forward_batch.batch_size + raw_num_token = raw_bs * self.num_tokens_per_bs # Pad if self.enable_dp_attention: @@ -363,15 +402,20 @@ def replay(self, forward_batch: ForwardBatch): self.out_cache_loc.zero_() # Common inputs - self.input_ids[:raw_bs].copy_(forward_batch.input_ids) + self.input_ids[:raw_num_token].copy_(forward_batch.input_ids) self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) - self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc) + self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) + self.positions[:raw_num_token].copy_(forward_batch.positions) + if self.is_encoder_decoder: self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) if forward_batch.mrope_positions is not None: self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) + if hasattr(forward_batch.spec_info, "hidden_states"): + self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states + # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( bs, @@ -379,33 +423,51 @@ def replay(self, forward_batch: ForwardBatch): self.seq_lens, forward_batch.seq_lens_sum + (bs - raw_bs), self.encoder_lens, + forward_batch.forward_mode, + forward_batch.spec_info, ) # Replay self.graphs[bs].replay() - next_token_logits = self.output_buffers[bs][:raw_bs] + next_token_logits, hidden_states = self.output_buffers[bs] - # Extract logprobs - if forward_batch.return_logprob: - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits, dim=-1 - ) - logits_output = LogitsProcessorOutput( - next_token_logits=next_token_logits, - next_token_logprobs=next_token_logprobs, + logits_output = LogitsProcessorOutput( + next_token_logits=next_token_logits[:raw_num_token], + hidden_states=( + 
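# Editorial sketch (not part of the patch): capture_one_batch_size/replay above follow the
# standard PyTorch CUDA-graph recipe: warm up on a side stream, capture the forward into
# static input buffers, then at serve time copy real inputs into those buffers and call
# graph.replay(). A minimal stand-alone version of that recipe (requires a CUDA device;
# the linear layer is just a placeholder model):
import torch

assert torch.cuda.is_available()

model = torch.nn.Linear(16, 16).cuda()
static_x = torch.zeros(8, 16, device="cuda")     # static input buffer, like self.input_ids[:bs]

# Warm up on a side stream before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(2):
        model(static_x)
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = model(static_x)                 # output buffer reused on every replay

# Replay: overwrite the static buffer, then rerun the recorded kernels.
static_x.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
print(static_out.shape)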
hidden_states[:raw_num_token] if hidden_states is not None else None + ), + ) + return logits_output + + def get_spec_info(self, num_tokens: int, positions: torch.Tensor): + spec_info = None + if self.model_runner.spec_algorithm.is_eagle(): + from sglang.srt.speculative.eagle_utils import ( + EAGLEDraftInput, + EagleVerifyInput, ) - return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums) - if return_top_logprob: - logits_metadata = LogitsMetadata( - forward_mode=ForwardMode.DECODE, - top_logprobs_nums=forward_batch.top_logprobs_nums, + + if self.model_runner.is_draft_worker: + spec_info = EAGLEDraftInput() + spec_info.load_server_args(self.model_runner.server_args) + spec_info.hidden_states = self.hidden_states[:num_tokens] + spec_info.positions = positions + spec_info.capture_hidden_mode = CaptureHiddenMode.FULL + else: + spec_info = EagleVerifyInput( + None, + None, + None, + None, + None, + None, + self.model_runner.server_args.speculative_num_draft_tokens, ) - logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs( - next_token_logprobs, logits_metadata - )[1] - else: - logits_output = LogitsProcessorOutput( - next_token_logits=next_token_logits, - ) + spec_info.custom_mask = torch.zeros( + (num_tokens * self.model_runner.model_config.context_len), + dtype=torch.bool, + device="cuda", + ) + spec_info.capture_hidden_mode = CaptureHiddenMode.FULL - return logits_output + return spec_info diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 3a5519956fe..8bd1052754c 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -38,6 +38,7 @@ import triton.language as tl from sglang.srt.layers.rotary_embedding import MRotaryEmbedding +from sglang.srt.utils import get_compiler_backend if TYPE_CHECKING: from sglang.srt.layers.attention import AttentionBackend @@ -45,6 +46,7 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo + from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm class ForwardMode(IntEnum): @@ -59,6 +61,11 @@ class ForwardMode(IntEnum): # No sequence to forward. For data parallel attention, some workers wil be IDLE if no sequence are allocated. IDLE = auto() + # Used in speculative decoding: verify a batch in the target model. + TARGET_VERIFY = auto() + # Used in speculative decoding: extend a batch in the draft model. + DRAFT_EXTEND = auto() + # A dummy first batch to start the pipeline for overlap scheduler. # It is now used for triggering the sampling_info_done event for the first prefill batch. 
DUMMY_FIRST = auto() @@ -67,7 +74,12 @@ def is_prefill(self): return self == ForwardMode.PREFILL def is_extend(self): - return self == ForwardMode.EXTEND or self == ForwardMode.MIXED + return ( + self == ForwardMode.EXTEND + or self == ForwardMode.MIXED + or self == ForwardMode.DRAFT_EXTEND + or self == self.TARGET_VERIFY + ) def is_decode(self): return self == ForwardMode.DECODE @@ -78,9 +90,40 @@ def is_mixed(self): def is_idle(self): return self == ForwardMode.IDLE + def is_target_verify(self): + return self == ForwardMode.TARGET_VERIFY + + def is_draft_extend(self): + return self == ForwardMode.DRAFT_EXTEND + + def is_cuda_graph(self): + return ( + self == ForwardMode.DECODE + or self == ForwardMode.TARGET_VERIFY + or self == ForwardMode.IDLE + ) + def is_dummy_first(self): return self == ForwardMode.DUMMY_FIRST + def is_decode_or_idle(self): + return self == ForwardMode.DECODE or self == ForwardMode.IDLE + + +class CaptureHiddenMode(IntEnum): + NULL = auto() + FULL = auto() + LAST = auto() + + def need_capture(self): + return self != CaptureHiddenMode.NULL + + def is_full(self): + return self == CaptureHiddenMode.FULL + + def is_last(self): + return self == CaptureHiddenMode.LAST + @dataclass class ForwardBatch: @@ -141,14 +184,19 @@ class ForwardBatch: token_to_kv_pool: BaseTokenToKVPool = None attn_backend: AttentionBackend = None - # For Qwen2-VL - mrope_positions: torch.Tensor = None - # For DP attention global_num_tokens: Optional[List[int]] = None gathered_buffer: Optional[torch.Tensor] = None can_run_dp_cuda_graph: bool = False + # Speculative decoding + spec_info: SpecInfo = None + spec_algorithm: SpeculativeAlgorithm = None + capture_hidden_mode: CaptureHiddenMode = None + + # For Qwen2-VL + mrope_positions: torch.Tensor = None + def compute_mrope_positions( self, model_runner: ModelRunner, batch: ModelWorkerBatch ): @@ -234,6 +282,12 @@ def init_new( can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph, lora_paths=batch.lora_paths, sampling_info=batch.sampling_info, + req_to_token_pool=model_runner.req_to_token_pool, + token_to_kv_pool=model_runner.token_to_kv_pool, + attn_backend=model_runner.attn_backend, + spec_algorithm=batch.spec_algorithm, + spec_info=batch.spec_info, + capture_hidden_mode=batch.capture_hidden_mode, input_embeds=batch.input_embeds, ) @@ -246,10 +300,21 @@ def init_new( ) if ret.forward_mode.is_idle(): + ret.positions = torch.empty((0,), device=device) return ret + # Override the positions with spec_info + if ( + ret.spec_info is not None + and getattr(ret.spec_info, "positions", None) is not None + ): + ret.positions = ret.spec_info.positions + # Init position information - if not ret.forward_mode.is_decode(): + if ret.forward_mode.is_decode(): + if ret.positions is None: + ret.positions = clamp_position(batch.seq_lens) + else: ret.extend_seq_lens = torch.tensor( batch.extend_seq_lens, dtype=torch.int32 ).to(device, non_blocking=True) @@ -258,13 +323,15 @@ def init_new( ).to(device, non_blocking=True) if model_runner.server_args.attention_backend != "torch_native": ret.extend_num_tokens = batch.extend_num_tokens - ret.positions, ret.extend_start_loc = compute_position_triton( + positions, ret.extend_start_loc = compute_position_triton( ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens ) else: - ret.positions, ret.extend_start_loc = compute_position_torch( + positions, ret.extend_start_loc = compute_position_torch( ret.extend_prefix_lens, ret.extend_seq_lens ) + if ret.positions is None: + ret.positions = positions 
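# Editorial sketch (not part of the patch): for decode batches the position of the next
# token is simply seq_len - 1, clamped at 0, which is what the torch.compile'd
# clamp_position above computes. CPU tensors are enough to show the result:
import torch


def clamp_position_sketch(seq_lens: torch.Tensor) -> torch.Tensor:
    return torch.clamp(seq_lens - 1, min=0).to(torch.int64)


print(clamp_position_sketch(torch.tensor([1, 5, 0, 17])))  # tensor([ 0,  4,  0, 16])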
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens ret.extend_seq_lens_cpu = batch.extend_seq_lens ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens @@ -272,11 +339,6 @@ def init_new( if model_runner.model_is_mrope: ret.compute_mrope_positions(model_runner, batch) - # Init attention information - ret.req_to_token_pool = model_runner.req_to_token_pool - ret.token_to_kv_pool = model_runner.token_to_kv_pool - ret.attn_backend = model_runner.attn_backend - # Init lora information if model_runner.server_args.lora_paths is not None: model_runner.lora_manager.prepare_lora_batch(ret) @@ -351,3 +413,8 @@ def compute_position_torch( extend_start_loc = torch.zeros_like(extend_seq_lens) extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0) return positions.to(torch.int64), extend_start_loc + + +@torch.compile(dynamic=True, backend=get_compiler_backend()) +def clamp_position(seq_lens): + return torch.clamp((seq_lens - 1), min=0).to(torch.int64) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3f0cbecac15..6fa1429dc2c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -17,24 +17,30 @@ import json import logging import time -from typing import Optional +from typing import List, Optional, Tuple import torch import torch.distributed as dist -from vllm.distributed import ( + +from sglang.srt.configs.device_config import DeviceConfig +from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.configs.model_config import AttentionArch, ModelConfig +from sglang.srt.distributed import ( get_tp_group, init_distributed_environment, initialize_model_parallel, set_custom_all_reduce, ) - -from sglang.srt.configs.device_config import DeviceConfig -from sglang.srt.configs.load_config import LoadConfig -from sglang.srt.configs.model_config import AttentionArch, ModelConfig +from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend from sglang.srt.layers.attention.triton_backend import TritonAttnBackend +from sglang.srt.layers.dp_attention import ( + get_attention_tp_group, + get_attention_tp_size, + initialize_dp_attention, +) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model @@ -48,15 +54,17 @@ ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader import get_model -from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( enable_show_time_cost, get_available_gpu_memory, init_custom_process_group, + is_cuda, is_hip, + monkey_patch_p2p_access_check, monkey_patch_vllm_gguf_config, - monkey_patch_vllm_p2p_access_check, set_cpu_offload_max_bytes, ) @@ -75,6 +83,7 @@ def __init__( tp_size: int, nccl_port: int, server_args: ServerArgs, + is_draft_worker: bool = False, ): # Parse args self.model_config = model_config @@ -85,16 +94,23 @@ def __init__( self.tp_size = 
tp_size self.dist_port = nccl_port self.server_args = server_args + self.is_draft_worker = is_draft_worker self.is_generation = model_config.is_generation self.is_multimodal = model_config.is_multimodal + self.should_log = tp_rank == 0 + self.spec_algorithm = SpeculativeAlgorithm.from_string( + server_args.speculative_algorithm + ) # Model-specific adjustment if ( self.model_config.attention_arch == AttentionArch.MLA and not self.server_args.disable_mla ): - logger.info("MLA optimization is turned on. Use triton backend.") - self.server_args.attention_backend = "triton" + # TODO: add MLA optimization on CPU + if self.server_args.device != "cpu": + logger.info("MLA optimization is turned on. Use triton backend.") + self.server_args.attention_backend = "triton" if self.server_args.enable_double_sparsity: logger.info( @@ -111,17 +127,26 @@ def __init__( ) if self.is_multimodal: - server_args.chunked_prefill_size = -1 self.mem_fraction_static *= 0.95 logger.info( - f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static} " - f"and turn off chunked prefill " + f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} " f"because this is a multimodal model." ) - # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically + + if self.model_config.hf_config.architectures == [ + "MllamaForConditionalGeneration" + ]: + logger.info("Automatically turn off --chunked-prefill-size for mllama.") + server_args.chunked_prefill_size = -1 + if self.model_config.hf_config.architectures == [ "Qwen2VLForConditionalGeneration" ]: + # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically + logger.info( + "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl." 
+ ) + server_args.chunked_prefill_size = -1 server_args.disable_radix_cache = True # Global vars @@ -142,6 +167,7 @@ def __init__( "enable_nan_detection": server_args.enable_nan_detection, "enable_dp_attention": server_args.enable_dp_attention, "enable_ep_moe": server_args.enable_ep_moe, + "device": server_args.device, } ) @@ -150,10 +176,22 @@ def __init__( # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() + self.memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=self.server_args.enable_memory_saver + ) + # Load the model self.sampler = Sampler() self.load_model() + # Apply torchao quantization + torchao_applied = getattr(self.model, "torchao_applied", False) + # In layered loading, torchao may have been applied + if not torchao_applied: + apply_torchao_config_to_model( + self.model, global_server_args_dict["torchao_config"] + ) + # Apply torch TP if the model supports it supports_torch_tp = getattr(self.model, "supports_torch_tp", False) if self.tp_size > 1 and supports_torch_tp: @@ -162,10 +200,6 @@ def __init__( else: self.torch_tp_applied = False - apply_torchao_config_to_model( - self.model, global_server_args_dict["torchao_config"] - ) - # Init memory pool and attention backends if server_args.lora_paths is not None: self.init_lora_manager() @@ -184,36 +218,50 @@ def __init__( def init_torch_distributed(self): logger.info("Init torch distributed begin.") - # Init torch distributed + torch.get_device_module(self.device).set_device(self.gpu_id) if self.device == "cuda": backend = "nccl" - # ToDO(liangan1):Just use gloo to bypass the initilization fail - # Need to use xccl for xpu backend in the future elif self.device == "xpu": + # TODO(liangan1): Just use gloo to bypass the initilization fail + # Need to use xccl for xpu backend in the future backend = "gloo" elif self.device == "hpu": backend = "hccl" + elif self.device == "cpu": + backend = "gloo" if not self.server_args.enable_p2p_check: - monkey_patch_vllm_p2p_access_check(self.gpu_id) + monkey_patch_p2p_access_check() + if self.server_args.dist_init_addr: dist_init_method = f"tcp://{self.server_args.dist_init_addr}" else: dist_init_method = f"tcp://127.0.0.1:{self.dist_port}" set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) - init_distributed_environment( - backend=backend, - world_size=self.tp_size, - rank=self.tp_rank, - local_rank=self.gpu_id, - distributed_init_method=dist_init_method, - ) - initialize_model_parallel(tensor_model_parallel_size=self.tp_size) + + if not self.is_draft_worker: + # Only initialize the distributed environment on the target model worker. + init_distributed_environment( + backend=backend, + world_size=self.tp_size, + rank=self.tp_rank, + local_rank=self.gpu_id, + distributed_init_method=dist_init_method, + ) + initialize_model_parallel(tensor_model_parallel_size=self.tp_size) + initialize_dp_attention( + enable_dp_attention=self.server_args.enable_dp_attention, + tp_rank=self.tp_rank, + tp_size=self.tp_size, + dp_size=self.server_args.dp_size, + ) + min_per_gpu_memory = get_available_gpu_memory( self.device, self.gpu_id, distributed=self.tp_size > 1 ) self.tp_group = get_tp_group() + self.attention_tp_group = get_attention_tp_group() # Check memory for tensor parallelism if self.tp_size > 1: @@ -231,7 +279,8 @@ def load_model(self): ) # This can reduce thread conflicts and speed up weight loading. 
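# Illustrative sketch (editor's addition, not part of the patch): the process-group
# backend that init_torch_distributed() above now selects per device. The diff also
# skips init_distributed_environment() for draft workers, presumably because the
# speculative-decoding draft runner is constructed in the same process as the target
# runner and can reuse its already-initialized groups. The helper below only restates
# the device-to-backend mapping from the diff.
_DIST_BACKEND_BY_DEVICE = {"cuda": "nccl", "xpu": "gloo", "hpu": "hccl", "cpu": "gloo"}

def pick_dist_backend(device: str) -> str:
    return _DIST_BACKEND_BY_DEVICE[device]

assert pick_dist_backend("cuda") == "nccl"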
- torch.set_num_threads(1) + if self.device != "cpu": + torch.set_num_threads(1) if self.device == "cuda": if torch.cuda.get_device_capability()[0] < 8: logger.info( @@ -242,20 +291,49 @@ def load_model(self): if torch.cuda.get_device_capability()[1] < 5: raise RuntimeError("SGLang only supports sm75 and above.") - # Prepare the vllm model config + # Prepare the model config self.load_config = LoadConfig( load_format=self.server_args.load_format, download_dir=self.server_args.download_dir, ) - if self.server_args.load_format == "gguf": monkey_patch_vllm_gguf_config() - self.model = get_model( - model_config=self.model_config, - load_config=self.load_config, - device_config=DeviceConfig(self.device), - ) + # Load the model + # Remove monkey_patch when linear.py quant remove dependencies with vllm + monkey_patch_vllm_parallel_state() + with self.memory_saver_adapter.region(): + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=DeviceConfig(self.device), + ) + monkey_patch_vllm_parallel_state(reverse=True) + + if self.server_args.kv_cache_dtype == "fp8_e4m3": + if self.server_args.quantization_param_path is not None: + if callable(getattr(self.model, "load_kv_cache_scales", None)): + self.model.load_kv_cache_scales( + self.server_args.quantization_param_path + ) + logger.info( + "Loaded KV cache scaling factors from %s", + self.server_args.quantization_param_path, + ) + else: + raise RuntimeError( + "Using FP8 KV cache and scaling factors provided but " + "model %s does not support loading scaling factors.", + self.model.__class__, + ) + else: + logger.warning( + "Using FP8 KV cache but no scaling factors " + "provided. Defaulting to scaling factors of 1.0. " + "This may lead to less accurate results!" 
+ ) + + # Parse other args self.sliding_window_size = ( self.model.get_attention_sliding_window_size() if hasattr(self.model, "get_attention_sliding_window_size") @@ -270,8 +348,10 @@ def load_model(self): f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB" ) - def update_weights_from_disk(self, model_path: str, load_format: str): - """Update engine weights online from disk.""" + def update_weights_from_disk( + self, model_path: str, load_format: str + ) -> tuple[bool, str]: + """Update engine weights in-place from the disk.""" from sglang.srt.model_loader.loader import ( DefaultModelLoader, device_loading_context, @@ -369,7 +449,7 @@ def init_weights_update_group( logger.info( f"init custom process group: master_address={master_address}, master_port={master_port}, " - f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}" + f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}" ) try: @@ -400,7 +480,6 @@ def update_weights_from_distributed(self, name, dtype, shape): target_dtype = ( dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype) ) - current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype assert ( self._model_update_group is not None @@ -421,6 +500,10 @@ def update_weights_from_distributed(self, name, dtype, shape): logger.error(error_msg) return False, error_msg + def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]): + self.model.load_weights(named_tensors) + return True, "Success" + def get_weights_by_name( self, name: str, truncate_size: int = 100 ) -> Optional[torch.Tensor]: @@ -464,7 +547,7 @@ def profile_max_num_token(self, total_gpu_memory: int): ) else: cell_size = ( - self.model_config.get_num_kv_heads(self.tp_size) + self.model_config.get_num_kv_heads(get_attention_tp_size()) * self.model_config.head_dim * self.model_config.num_hidden_layers * 2 @@ -489,12 +572,37 @@ def init_memory_pool( self.kv_cache_dtype = torch.float8_e5m2fnuz else: self.kv_cache_dtype = torch.float8_e5m2 + elif self.server_args.kv_cache_dtype == "fp8_e4m3": + if is_cuda(): + self.kv_cache_dtype = torch.float8_e4m3fn else: raise ValueError( f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}." ) self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory) + + if max_num_reqs is None: + max_num_reqs = min( + max( + int( + self.max_total_num_tokens / self.model_config.context_len * 512 + ), + 2048, + ), + 4096, + ) + + if not self.spec_algorithm.is_none(): + if self.is_draft_worker: + self.max_total_num_tokens = self.server_args.draft_runner_cache_size + else: + self.server_args.draft_runner_cache_size = ( + self.max_total_num_tokens + + max_num_reqs * self.server_args.speculative_num_steps + + 100 + ) + if max_total_tokens is not None: if max_total_tokens > self.max_total_num_tokens: logging.warning( @@ -509,22 +617,11 @@ def init_memory_pool( "Not enough memory. Please try to increase --mem-fraction-static." 
) - if max_num_reqs is None: - max_num_reqs = min( - max( - int( - self.max_total_num_tokens / self.model_config.context_len * 512 - ), - 2048, - ), - 4096, - ) - self.req_to_token_pool = ReqToTokenPool( size=max_num_reqs + 1, max_context_len=self.model_config.context_len + 4, device=self.device, - use_records=False, + enable_memory_saver=self.server_args.enable_memory_saver, ) if ( self.model_config.attention_arch == AttentionArch.MLA @@ -537,25 +634,28 @@ def init_memory_pool( qk_rope_head_dim=self.model_config.qk_rope_head_dim, layer_num=self.model_config.num_hidden_layers, device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, ) elif self.server_args.enable_double_sparsity: self.token_to_kv_pool = DoubleSparseTokenToKVPool( self.max_total_num_tokens, dtype=self.kv_cache_dtype, - head_num=self.model_config.get_num_kv_heads(self.tp_size), + head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, device=self.device, heavy_channel_num=self.server_args.ds_heavy_channel_num, + enable_memory_saver=self.server_args.enable_memory_saver, ) else: self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, dtype=self.kv_cache_dtype, - head_num=self.model_config.get_num_kv_heads(self.tp_size), + head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, ) logger.info( f"Memory pool end. " @@ -596,7 +696,6 @@ def init_attention_backend(self): ) def init_double_sparsity_channel_config(self, selected_channel): - selected_channel = "." + selected_channel + "_proj" self.sorted_channels = [] # load channel config @@ -639,10 +738,6 @@ def apply_torch_tp(self): tensor_parallel(self.model, device_mesh) def forward_decode(self, forward_batch: ForwardBatch): - if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch): - return self.cuda_graph_runner.replay(forward_batch) - - forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64) self.attn_backend.init_forward_metadata(forward_batch) return self.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch @@ -672,14 +767,18 @@ def forward_extend(self, forward_batch: ForwardBatch): ) def forward_idle(self, forward_batch: ForwardBatch): - if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch): - return self.cuda_graph_runner.replay(forward_batch) - return self.model.forward( forward_batch.input_ids, forward_batch.positions, forward_batch ) def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput: + if ( + forward_batch.forward_mode.is_cuda_graph() + and self.cuda_graph_runner + and self.cuda_graph_runner.can_run(forward_batch) + ): + return self.cuda_graph_runner.replay(forward_batch) + if forward_batch.forward_mode.is_decode(): return self.forward_decode(forward_batch) elif forward_batch.forward_mode.is_extend(): @@ -687,11 +786,12 @@ def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput: elif forward_batch.forward_mode.is_idle(): return self.forward_idle(forward_batch) else: - raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}") + raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}") def sample( self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch ) -> torch.Tensor: + # Apply logit bias sampling_info 
= forward_batch.sampling_info if sampling_info.sampling_info_done: # Overlap mode: the function update_regex_vocab_mask was executed @@ -702,35 +802,17 @@ def sample( # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass. sampling_info.update_regex_vocab_mask() sampling_info.update_penalties() - logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info) - - # Sample the next tokens. - next_token_ids = self.sampler(logits, sampling_info) + sampling_info.apply_logits_bias(logits_output.next_token_logits) + + # Sample the next tokens + next_token_ids = self.sampler( + logits_output, + sampling_info, + forward_batch.return_logprob, + forward_batch.top_logprobs_nums, + ) return next_token_ids - def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): - # Apply logit_bias - if sampling_info.logit_bias is not None: - logits.add_(sampling_info.logit_bias) - - # min-token, presence, frequency - if sampling_info.linear_penalties is not None: - logits.add_(sampling_info.linear_penalties) - - # repetition - if sampling_info.scaling_penalties is not None: - logits = torch.where( - logits > 0, - logits / sampling_info.scaling_penalties, - logits * sampling_info.scaling_penalties, - ) - - # Apply regex vocab_mask - if sampling_info.vocab_mask is not None: - sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask) - - return logits - @property def model_is_mrope(self) -> bool: """Detect if the model has "mrope" rope_scaling type. diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index e0b03d7710e..9e6b09488e6 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -21,14 +21,14 @@ from torch import nn from transformers import AutoModelForCausalLM, PretrainedConfig from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.distributed import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig, LoadFormat from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_loader.utils import ( get_model_architecture, @@ -374,6 +374,78 @@ def load_model( return model.eval() +class LayeredModelLoader(DefaultModelLoader): + """Model loader that loads weights layer by layer so that one can quantize a + layer before loading another to make the peak memory envelope smaller.""" + + def __init__(self, load_config: LoadConfig): + # Back to the default load format + load_config.load_format = LoadFormat.AUTO + super().__init__(load_config) + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + ) -> nn.Module: + from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model + from sglang.srt.managers.schedule_batch import global_server_args_dict + + torchao_config = global_server_args_dict.get("torchao_config") + target_device = torch.device(device_config.device) + + with set_default_torch_dtype(model_config.dtype): + # Create model on meta device + with torch.device("meta"): + model = _initialize_model( + model_config, + self.load_config, + ) + + # Check model's layered load support + if not hasattr(model, "load_weights_to_module"): + 
raise ValueError( + "LayeredModelLoader requires the model to have a " + "`load_weights_to_module` method. " + f"{model_config.model_path} does not support it." + ) + + # Get all weights from disk + weights = self._get_all_weights(model_config, model) + + # Helper function to recursively fill the weights of a module + def fill_module(module, fqn: List[str], weights): + """ + fqn: list of strings representing the fully qualified name of `module`. + """ + # Layer by layer + for name, submod in module.named_children(): + fill_module(submod, fqn + [name], weights) + + # First materialize on target device + module.to_empty(device=target_device, recurse=False) + fqn_path = ".".join(fqn) + # Fill weights + model.load_weights_to_module( + fqn_path, + weights, + ) + # Quantize weights if applicable + if torchao_config and "proj" in fqn_path: + # Note: `None` here is needed to indicate no filter, see + # `apply_torchao_config_to_model` for details. + apply_torchao_config_to_model(module, torchao_config, None) + + # Start calling on root module + fill_module(model, [], weights) + + if torchao_config: + model.torchao_applied = True + + return model.eval() + + class DummyModelLoader(BaseModelLoader): """Model loader that will set model weights to random values.""" @@ -496,7 +568,8 @@ def load_model( device_config: DeviceConfig, ) -> nn.Module: from safetensors.torch import safe_open - from vllm.distributed import get_tensor_model_parallel_rank + + from sglang.srt.distributed import get_tensor_model_parallel_rank local_model_path = self._prepare_weights( model_config.model_path, model_config.revision @@ -556,7 +629,8 @@ def save_model( max_size: Optional[int] = None, ) -> None: from safetensors.torch import save_file - from vllm.distributed import get_tensor_model_parallel_rank + + from sglang.srt.distributed import get_tensor_model_parallel_rank if pattern is None: pattern = ShardedStateLoader.DEFAULT_PATTERN @@ -770,6 +844,21 @@ def _get_quantized_weights_iterator( quant_state_dict, ) + def _is_8bit_weight_name(self, weight_name: str): + quantized_suffix = {".scb", ".weight_format"} + return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix) + + def _is_4bit_weight_name(self, weight_name: str): + quantized_suffix = { + "absmax", + "quant_map", + "nested_absmax", + "nested_quant_map", + "bitsandbytes", + } + suffix = weight_name.split(".")[-1] + return any(q_suffix in suffix for q_suffix in quantized_suffix) + def _quantized_8bit_generator( self, hf_weights_files, use_safetensors, quant_state_dict ) -> Generator: @@ -779,21 +868,18 @@ def _quantized_8bit_generator( if not weight_name.lower().endswith(".scb"): continue - weight_key = weight_name.lower().replace(".scb", ".qweight") + weight_key = weight_name.lower().replace(".scb", ".weight") quant_state_dict[weight_key] = weight_tensor for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors ): - - if not weight_name.endswith((".weight", ".bias")): + if self._is_8bit_weight_name(weight_name): continue - qweight_name = weight_name.replace(".weight", ".qweight") - - if qweight_name in quant_state_dict: + if weight_name in quant_state_dict: set_weight_attrs(weight_tensor, {"load_in_8bit": True}) - yield qweight_name, weight_tensor + yield weight_name, weight_tensor else: yield weight_name, weight_tensor @@ -806,7 +892,7 @@ def _quantized_4bit_generator( weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors) temp_state_dict = {} for weight_name, weight_tensor in weight_iterator: - if 
weight_name.endswith((".weight", ".bias")): + if not self._is_4bit_weight_name(weight_name): continue # bitsandbytes library requires # weight.quant_state.bitsandbytes__* in CPU @@ -830,16 +916,15 @@ def _parse_quant_state(param_name: str, temp_state_dict: Dict) -> QuantState: hf_weights_files, use_safetensors ): - if not weight_name.endswith((".weight", ".bias")): + if self._is_4bit_weight_name(weight_name): continue if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or ( f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict ): quant_state = _parse_quant_state(weight_name, temp_state_dict) - weight_name = weight_name.replace(".weight", ".qweight") quant_state_dict[weight_name] = quant_state - yield weight_name.replace(".weight", ".qweight"), weight_tensor + yield weight_name, weight_tensor else: yield weight_name, weight_tensor @@ -1136,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.GGUF: return GGUFModelLoader(load_config) + if load_config.load_format == LoadFormat.LAYERED: + return LayeredModelLoader(load_config) + return DefaultModelLoader(load_config) diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 13b323b5d32..c07a346f471 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -9,7 +9,17 @@ import os import tempfile from collections import defaultdict -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, + Union, +) import filelock import gguf @@ -17,12 +27,13 @@ import numpy as np import torch from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm -from vllm.distributed import get_tensor_model_parallel_rank from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config from sglang.srt.utils import print_warning_once @@ -393,8 +404,13 @@ def np_cache_weights_iterator( def safetensors_weights_iterator( hf_weights_files: List[str], + is_all_weights_sharded: bool = False, ) -> Generator[Tuple[str, torch.Tensor], None, None]: - """Iterate over the weights in the model safetensor files.""" + """Iterate over the weights in the model safetensor files. + + If is_all_weights_sharded is True, it uses more optimize read by reading an + entire file instead of reading each tensor one by one. 
+ """ enable_tqdm = ( not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0 ) @@ -404,9 +420,14 @@ def safetensors_weights_iterator( disable=not enable_tqdm, bar_format=_BAR_FORMAT, ): - with safe_open(st_file, framework="pt") as f: - for name in f.keys(): # noqa: SIM118 - param = f.get_tensor(name) + if not is_all_weights_sharded: + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + else: + result = load_file(st_file, device="cpu") + for name, param in result.items(): yield name, param @@ -638,3 +659,121 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: # If there were no matches, return the untouched param name return name + + +# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py +class KVCacheQuantSchema(BaseModel): + dtype: str + # Each key is a TP rank. Each value is a dictionary mapping a TP rank's + # layer indices to their per-tensor KV cache scaling factor. + # TODO: Consider pulling this and its validation methods out into its + # own schema class (tricky as its members are variable) + scaling_factor: Dict[int, Dict[int, float]] + + @model_validator(mode="after") + def check_is_fp8(self) -> "KVCacheQuantSchema": + assert self.dtype == "float8_e4m3fn", ( + "Loaded scaling factors intended for KV cache dtype = " + f"{self.dtype} rather than float8_e4m3fn!" + ) + return self + + @model_validator(mode="after") + def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_size = context["tp_size"] + num_hidden_layers = context["num_hidden_layers"] + assert len(self.scaling_factor) == tp_size, ( + f"Loaded dictionary has TP size {len(self.scaling_factor)} " + f"but LLM engine is currently running with TP size {tp_size}." + ) + for tp_rank, layer_maps in self.scaling_factor.items(): + assert len(layer_maps) == num_hidden_layers, ( + f"KV cache scales map for TP rank {tp_rank} is malformed. " + f"Expected {num_hidden_layers} layers, got " + f"{len(layer_maps)}." + ) + for i in range(tp_size): + assert ( + i in self.scaling_factor + ), f"KV cache scales map for TP rank {i} not found." + return self + + @model_validator(mode="after") + def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_rank = context["tp_rank"] + num_hidden_layers = context["num_hidden_layers"] + layer_scales_map = self.scaling_factor[tp_rank] + for i in range(num_hidden_layers): + assert i in layer_scales_map, ( + f"Could not find KV cache scales for layer {i} in " + f"TP rank {tp_rank}." + ) + return self + + +class QuantParamSchema(BaseModel): + # TODO: Generalize and extend with more fields + # (e.g. weights/activations params) once functionality is enabled + model_config = ConfigDict(protected_namespaces=()) + model_type: Optional[str] + kv_cache: KVCacheQuantSchema + + @model_validator(mode="after") + def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema": + context = info.context + if context: + model_type = context.get("model_type", None) + if model_type is not None: + assert model_type == self.model_type, ( + f"Model type is {model_type} but loaded " + f"scaling factors belonging to different " + f"model type {self.model_type}!" 
+ ) + return self + + +def kv_cache_scales_loader( + filename: str, + tp_rank: int, + tp_size: int, + num_hidden_layers: int, + model_type: Optional[str], +) -> Iterable[Tuple[int, float]]: + """ + A simple utility to read in KV cache scaling factors that have been + previously serialized to disk. Used by the model to populate the appropriate + KV cache scaling factors. The serialization should represent a dictionary + whose keys are the TP ranks and values are another dictionary mapping layers + to their KV cache scaling factors. + """ + try: + with open(filename) as f: + context = { + "model_type": model_type, + "num_hidden_layers": num_hidden_layers, + "tp_rank": tp_rank, + "tp_size": tp_size, + } + schema_dct = json.load(f) + schema = QuantParamSchema.model_validate(schema_dct, context=context) + layer_scales_map = schema.kv_cache.scaling_factor[tp_rank] + return layer_scales_map.items() + except FileNotFoundError: + logger.error("File or directory '%s' not found.", filename) + except json.JSONDecodeError: + logger.error("Error decoding JSON in file '%s'.", filename) + except Exception: + logger.error("An error occurred while reading '%s'.", filename) + # This section is reached if and only if any of the excepts are hit + # Return an empty iterable (list) => no KV cache scales are loaded + # which ultimately defaults to 1.0 scales + logger.warning( + "Defaulting to KV cache scaling factors = 1.0 for all " + "layers in TP rank %d as an error occurred during loading.", + tp_rank, + ) + return [] diff --git a/python/sglang/srt/model_parallel.py b/python/sglang/srt/model_parallel.py index 778347b8ef3..53c8b622e4d 100644 --- a/python/sglang/srt/model_parallel.py +++ b/python/sglang/srt/model_parallel.py @@ -2,18 +2,18 @@ Common utilities for torch model parallelism. """ -from typing import Optional +from typing import Optional, Sequence import torch +import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh try: - from torch.distributed.tensor import DTensor, Shard + import torch.distributed.tensor as dt except ImportError: # torch 2.4 or older - from torch.distributed._tensor import DTensor, Shard + import torch.distributed._tensor as dt -from torch.distributed._functional_collectives import AsyncCollectiveTensor from torch.distributed.tensor.parallel import ( ColwiseParallel, RowwiseParallel, @@ -21,6 +21,50 @@ ) +def _shard_tensor( + full_tensor: torch.Tensor, + device_mesh: DeviceMesh, + placements: Sequence[dt.Shard], +) -> "dt.DTensor": + """ + Locally shards a full tensor based on indicated sharding arrangement, and + returns a DTensor containing the local shard. + + .. warning:: This is a private API that is subject to change. It skips the + communication otherwise required by `distribute_tensor`. It is only + applicable to cases where all ranks have the same `full_tensor`. For + example, in distributed inference all ranks load from the same + checkpoint. This API will not check for data equality between ranks, it + is thus user's responsibility to ensure the `full_tensor` is the same + across ranks. + + Args: + full_tensor (torch.Tensor): the full tensor to be sharded. + device_mesh (:class:`DeviceMesh`): DeviceMesh to place the + DTensor. Must have same dimension as the number of placements. + placements (Sequence[:class:`Shard`]): the placements that + describes how to place the local tensor on DeviceMesh. + + Returns: + A :class:`DTensor` object with the shard as its local tensor. 
+ + Examples: + >>> # xdoctest: +SKIP("need world_size and rank") + >>> device_mesh = dist.init_device_mesh("cuda", (world_size,)) + >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}") + >>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(1)]) + """ + shape, offset = dt._utils.compute_local_shape_and_global_offset( + full_tensor.shape, device_mesh, placements + ) + slices = [ + slice(cur_offset, cur_offset + cur_shape) + for cur_shape, cur_offset in zip(shape, offset) + ] + local_tensor = full_tensor[slices] + return dt.DTensor.from_local(local_tensor, device_mesh, placements) + + class ColwiseParallelSharded(ColwiseParallel): """ A version of ColwiseParallel where the local weight has been already @@ -34,7 +78,7 @@ def _partition_linear_fn(self, name, module, device_mesh): # means Colwise as Linear is input * weight^T + bias, where # weight would become Shard(1) for name, param in module.named_parameters(): - dtensor = DTensor.from_local(param, device_mesh, [Shard(0)]) + dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)]) dist_param = torch.nn.Parameter(dtensor, requires_grad=False) module.register_parameter(name, dist_param) @@ -47,6 +91,23 @@ class RowwiseParallelMaybeWait(RowwiseParallel): AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`. """ + def _partition_linear_fn(self, name, module, device_mesh): + # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1) + # means Rowwise as nn.Linear is input * weight^T + bias, where + # weight would become Shard(0) + module.register_parameter( + "weight", + nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])), + ) + if getattr(module, "bias", None) is not None: + # The Linear module has bias + module.register_parameter( + "bias", + nn.Parameter( + dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()]) + ), + ) + @staticmethod def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): outputs = super( diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index 3bd60c25d3e..066157f05ce 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -24,22 +24,22 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.layers.linear import ( +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - -from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 9c3bc2ee9e0..222cc3e2d80 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -21,10 +21,9 @@ import torch from torch import nn from torch.nn import LayerNorm -from 
vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.transformers_utils.configs import ChatGLMConfig +from sglang.srt.configs import ChatGLMConfig +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -35,6 +34,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index 83ac3d8671b..e4b291b66cb 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -44,12 +44,11 @@ from torch import nn from torch.nn.parameter import Parameter from transformers import PretrainedConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.linear import ( MergedColumnParallelLinear, @@ -59,9 +58,13 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from sglang.srt.utils import get_compiler_backend, set_weight_attrs @@ -372,10 +375,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. 
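# Illustrative sketch (editor's addition, not part of the patch): the on-disk JSON
# layout that kv_cache_scales_loader() / QuantParamSchema, added in weight_utils.py
# above, expect for the quantization_param_path server argument -- a per-TP-rank,
# per-layer map of KV cache scaling factors. All values below are made up for a
# hypothetical 1-rank, 2-layer model.
import json

quant_params = {
    "model_type": "llama",                   # must match the running model type
    "kv_cache": {
        "dtype": "float8_e4m3fn",            # the only dtype the schema accepts
        "scaling_factor": {
            "0": {"0": 0.0218, "1": 0.0274}  # TP rank 0: layer index -> scale
        },
    },
}
print(json.dumps(quant_params, indent=2))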
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) -EntryClass = CohereForCausalLM +class Cohere2ForCausalLM(CohereForCausalLM): + pass + + +EntryClass = [CohereForCausalLM, Cohere2ForCausalLM] diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 45561d1dbc0..92fc679391f 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -19,30 +19,33 @@ import torch import torch.nn as nn -from vllm.distributed import ( + +from sglang.srt.configs import DbrxConfig +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.transformers_utils.configs.dbrx import DbrxConfig - -from sglang.srt.layers.fused_moe_triton import fused_moe from sglang.srt.layers.linear import ( QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.fused_moe_triton import fused_moe from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from sglang.srt.utils import set_weight_attrs @@ -411,6 +414,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, weight_name) break else: + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index ce1b152fbc7..7d2c0700fe4 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -21,15 +21,13 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.fused_moe_triton import fused_moe from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( MergedColumnParallelLinear, @@ -38,8 +36,10 @@ RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.fused_moe_triton import fused_moe from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 63cea92c289..4384410476c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -19,20 +19,18 @@ from typing import Any, Dict, Iterable, Optional, Tuple import torch +import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig from vllm import _custom_ops as ops -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.ep_moe.layer import EPMoE -from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -41,8 +39,16 @@ RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.ep_moe.layer import EPMoE +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.fp8_utils import ( + block_quant_to_tensor_quant, + input_to_float8, + normalize_e4m3fn_to_e4m3fnuz, +) from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -50,10 +56,12 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda_available, is_hip + +is_hip_ = is_hip() -if is_flashinfer_available(): - from flashinfer import bmm_fp8 +if is_cuda_available(): + from 
sgl_kernel import bmm_fp8 class DeepseekV2MLP(nn.Module): @@ -90,6 +98,24 @@ def forward(self, x): return x +class MoEGate(nn.Module): + def __init__(self, config): + super().__init__() + self.weight = nn.Parameter( + torch.empty((config.n_routed_experts, config.hidden_size)) + ) + if config.topk_method == "noaux_tc": + self.e_score_correction_bias = nn.Parameter( + torch.empty((config.n_routed_experts)) + ) + else: + self.e_score_correction_bias = None + + def forward(self, hidden_states): + logits = F.linear(hidden_states, self.weight, None) + return logits + + class DeepseekV2MoE(nn.Module): def __init__( @@ -114,6 +140,8 @@ def __init__( "Only silu is supported for now." ) + self.gate = MoEGate(config=config) + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE self.experts = MoEImpl( num_experts=config.n_routed_experts, @@ -125,11 +153,9 @@ def __init__( use_grouped_topk=True, num_expert_group=config.n_group, topk_group=config.topk_group, + correction_bias=self.gate.e_score_correction_bias, ) - self.gate = ReplicatedLinear( - config.hidden_size, config.n_routed_experts, bias=False, quant_config=None - ) if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV2MLP( @@ -146,7 +172,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.n_shared_experts is not None: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) - router_logits, _ = self.gate(hidden_states) + router_logits = self.gate(hidden_states) final_hidden_states = ( self.experts(hidden_states=hidden_states, router_logits=router_logits) * self.routed_scaling_factor @@ -167,15 +193,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: return 0.1 * mscale * math.log(scale) + 1.0 -def input_to_float8(x, dtype=torch.float8_e4m3fn): - finfo = torch.finfo(dtype) - min_val, max_val = x.aminmax() - amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) - scale = finfo.max / amax - x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) - return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() - - class DeepseekV2Attention(nn.Module): def __init__( @@ -254,13 +271,14 @@ def __init__( quant_config=quant_config, ) rope_scaling["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope( + self.rotary_emb = get_rope_wrapper( qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, is_neox_style=False, + device=global_server_args_dict["device"], ) if rope_scaling: @@ -439,7 +457,10 @@ def __init__( quant_config=quant_config, ) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) - rope_scaling["rope_type"] = "deepseek_yarn" + + if rope_scaling: + rope_scaling["rope_type"] = "deepseek_yarn" + self.rotary_emb = get_rope( qk_rope_head_dim, rotary_dim=qk_rope_head_dim, @@ -454,6 +475,8 @@ def __init__( scaling_factor = rope_scaling["factor"] mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale + else: + self.rotary_emb.forward = self.rotary_emb.forward_native self.attn_mqa = RadixAttention( self.num_local_heads, @@ -554,7 +577,13 @@ def forward_absorb( ) q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - if self.w_kc.dtype == torch.float8_e4m3fn: + if self.w_kc.dtype == torch.float8_e4m3fnuz: + # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz + 
q_nope_out = torch.bmm( + q_nope.to(torch.bfloat16).transpose(0, 1), + self.w_kc.to(torch.bfloat16) * self.w_scale, + ) + elif self.w_kc.dtype == torch.float8_e4m3fn: q_nope_val, q_nope_scale = input_to_float8( q_nope.transpose(0, 1), torch.float8_e4m3fn ) @@ -579,7 +608,13 @@ def forward_absorb( attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) - if self.w_vc.dtype == torch.float8_e4m3fn: + if self.w_vc.dtype == torch.float8_e4m3fnuz: + # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz + attn_bmm_output = torch.bmm( + attn_output.to(torch.bfloat16).transpose(0, 1), + self.w_vc.to(torch.bfloat16) * self.w_scale, + ) + elif self.w_vc.dtype == torch.float8_e4m3fn: attn_output_val, attn_output_scale = input_to_float8( attn_output.transpose(0, 1), torch.float8_e4m3fn ) @@ -821,10 +856,9 @@ def forward( forward_batch: ForwardBatch, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch) - if not forward_batch.forward_mode.is_idle(): - return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch - ) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -845,6 +879,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: + # TODO(HandH1998): Modify it when nextn is supported. + if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + if num_nextn_layers > 0 and name.startswith("model.layers"): + name_list = name.split(".") + if ( + len(name_list) >= 3 + and int(name_list[2]) >= self.config.num_hidden_layers + ): + continue if "rotary_emb.inv_freq" in name: continue for param_name, weight_name, shard_id in stacked_params_mapping: @@ -909,13 +953,45 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ).T else: w = self_attn.kv_b_proj.weight + # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. + # This may affect the accuracy of fp8 model. 
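# Editor's note (illustrative, not part of the patch): `input_to_float8`, used in
# forward_absorb() above, previously lived in this file (its local definition is
# deleted in this diff) and is now imported from
# sglang.srt.layers.quantization.fp8_utils, presumably unchanged. Its removed
# definition is restated here for context -- a per-tensor quantization that returns
# the fp8 tensor plus the reciprocal (dequantization) scale consumed by bmm_fp8:
import torch

def input_to_float8_reference(x, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()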
+ if hasattr(self.quant_config, "weight_block_size") and w.dtype in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ): + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + if is_hip_: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale_inv, + input_scale=None, + ) + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + + w, scale = block_quant_to_tensor_quant( + weight, weight_scale, weight_block_size + ) + self_attn.w_scale = scale w_kc, w_vc = w.unflatten( 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2) self_attn.w_vc = w_vc.contiguous().transpose(1, 2) - if hasattr(self_attn.kv_b_proj, "weight_scale"): + if ( + hasattr(self_attn.kv_b_proj, "weight_scale") + and self_attn.w_scale is None + ): self_attn.w_scale = self_attn.kv_b_proj.weight_scale + if is_hip_: + self_attn.w_scale *= 2.0 + + +class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): + pass -EntryClass = DeepseekV2ForCausalLM +EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM] diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index 536c253c33a..10be1e74d61 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -20,9 +20,8 @@ import torch from torch import nn -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -33,6 +32,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 10949a2f572..9940c569e25 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -21,9 +21,8 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -34,6 +33,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index dbca7268803..06a7b030260 100644 --- 
a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -15,13 +15,13 @@ # Adapted from: # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py -from typing import Iterable, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.linear import ( @@ -32,9 +32,13 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from sglang.srt.utils import make_layers @@ -44,23 +48,6 @@ def get_attention_sliding_window_size(config): return config.sliding_window - 1 -# FIXME: temporary solution, remove after next vllm release -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding - - -class GemmaRotaryEmbedding(RotaryEmbedding): - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: - # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107 - inv_freq = 1.0 / ( - base - ** ( - torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() - / self.rotary_dim - ) - ) - return inv_freq - - class Gemma2MLP(nn.Module): def __init__( self, @@ -143,14 +130,12 @@ def __init__( bias=config.attention_bias, quant_config=quant_config, ) - # from vLLM: TODO(woosuk): Use the `get_rope` interface. 
- self.rotary_emb = GemmaRotaryEmbedding( + self.rotary_emb = get_rope( self.head_dim, - self.head_dim, - max_position_embeddings, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, base=self.rope_theta, is_neox_style=True, - dtype=torch.get_default_dtype(), ) use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window") @@ -307,6 +292,25 @@ def forward( class Gemma2ForCausalLM(nn.Module): + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -355,6 +359,40 @@ def forward( input_ids, hidden_states, self.model.embed_tokens, forward_batch ) + def get_hidden_dim(self, module_name): + # return input_dim, output_dim + if module_name in ["q_proj", "qkv_proj"]: + return ( + self.config.hidden_size, + self.config.head_dim * self.config.num_attention_heads, + ) + elif module_name in ["o_proj"]: + return ( + self.config.head_dim * self.config.num_attention_heads, + self.config.hidden_size, + ) + elif module_name in ["kv_proj"]: + return ( + self.config.hidden_size, + self.config.head_dim * self.config.num_key_value_heads, + ) + elif module_name == "gate_up_proj": + return self.config.hidden_size, self.config.intermediate_size + elif module_name == "down_proj": + return self.config.intermediate_size, self.config.hidden_size + else: + raise NotImplementedError() + + def get_module_name(self, name): + params_mapping = { + "q_proj": "qkv_proj", + "k_proj": "qkv_proj", + "v_proj": "qkv_proj", + "gate_proj": "gate_up_proj", + "up_proj": "gate_up_proj", + } + return params_mapping.get(name, name) + def get_attention_sliding_window_size(self): return get_attention_sliding_window_size(self.config) @@ -389,6 +427,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/gemma2_reward.py b/python/sglang/srt/models/gemma2_reward.py index e5c2fc07aaf..1fe87c30aef 100644 --- a/python/sglang/srt/models/gemma2_reward.py +++ b/python/sglang/srt/models/gemma2_reward.py @@ -32,7 +32,6 @@ def __init__( ) -> None: super().__init__() self.config = config - self.torchao_config = None self.quant_config = quant_config self.num_labels = config.num_labels self.model = Gemma2Model(config, quant_config=quant_config) diff --git a/python/sglang/srt/models/gpt2.py b/python/sglang/srt/models/gpt2.py index 144ad8bbf72..04c3005ce2f 100644 --- a/python/sglang/srt/models/gpt2.py +++ b/python/sglang/srt/models/gpt2.py @@ -17,16 +17,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPT-2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from torch import nn from transformers import GPT2Config -from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -# from sglang.srt.layers.activation import get_act_fn +from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index f2f5ebd5204..0d705fb41b6 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -21,8 +21,8 @@ import torch from torch import nn from transformers import GPTBigCodeConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.linear import ( ColumnParallelLinear, diff --git a/python/sglang/srt/models/granite.py b/python/sglang/srt/models/granite.py new file mode 100644 index 00000000000..255f23227ff --- /dev/null +++ b/python/sglang/srt/models/granite.py @@ -0,0 +1,517 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +# Adapted from +# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1 +"""Inference-only Granite model compatible with HuggingFace weights.""" + +import logging +from typing import Any, Dict, Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import GraniteConfig + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.pooler import Pooler, PoolingType +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.utils import get_exception_traceback + +logger = logging.getLogger(__name__) + + +class GraniteMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class GraniteAttention(nn.Module): + def __init__( + self, + config: GraniteConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + layer_id: int = 0, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + rope_is_neox_style: bool = True, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
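# A minimal sketch of the KV-head partitioning rule used just below, with
# hypothetical head counts: when there are at least as many KV heads as
# tensor-parallel ranks, each rank keeps an equal slice of the heads;
# otherwise every rank holds one replicated head, which is what the
# max(1, ...) below expresses.
def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    if total_num_kv_heads >= tp_size:
        assert total_num_kv_heads % tp_size == 0  # partition evenly
    else:
        assert tp_size % total_num_kv_heads == 0  # replicate evenly
    return max(1, total_num_kv_heads // tp_size)

# kv_heads_per_rank(8, 2) -> 4; kv_heads_per_rank(8, 16) -> 1 (replicated)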
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr( + config, "head_dim", self.hidden_size // self.total_num_heads + ) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.attention_multiplier + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=rope_is_neox_style, + ) + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class GraniteDecoderLayer(nn.Module): + def __init__( + self, + config: GraniteConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.residual_multiplier = config.residual_multiplier + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None + ): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings + ) + rope_is_neox_style = getattr(config, "rope_is_neox_style", True) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.self_attn = GraniteAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + layer_id=layer_id, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + rope_is_neox_style=rope_is_neox_style, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = GraniteMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + 
hidden_states = ( + self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + * self.residual_multiplier + ) # multiplier for Maximal Update Parameterization + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) * self.residual_multiplier + return hidden_states, residual + + +class GraniteModel(nn.Module): + def __init__( + self, + config: GraniteConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, config.hidden_size + ) + self.layers = nn.ModuleList( + [ + GraniteDecoderLayer( + config, i, quant_config=quant_config, prefix=f"model.layers.{i}" + ) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + residual = None + hidden_states *= self.config.embedding_multiplier + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class GraniteForCausalLM(nn.Module): + def __init__( + self, + config: GraniteConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self.model = GraniteModel(config, quant_config=quant_config) + # If tie_word_embeddings == True, then input and output embeddings are + # the same tensor. Enforce during object creation so that weights will + # load correctly even if the LM head weights don't have a separate entry + # in the state dict. + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + if self.config.tie_word_embeddings: + self.lm_head.tie_weights(self.model.embed_tokens) + + # Granite logit scaling factors are applied via division, but + # LogitsProcessor expects a multiplicative factor. 
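# A small illustration, with a hypothetical logits_scaling value, of the
# conversion described in the comment above: dividing logits by Granite's
# scaling factor is equivalent to multiplying them by its reciprocal, which
# is the multiplicative form the LogitsProcessor below consumes.
logits_scaling = 16.0                     # hypothetical config value
logit_scale = 1.0 / logits_scaling
raw_logit = 8.0
assert raw_logit / logits_scaling == raw_logit * logit_scale  # 0.5 either way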
+ if hasattr(config, "logits_scaling"): + logit_scale = 1.0 / config.logits_scaling + else: + logit_scale = None + self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale) + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self.stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + get_embedding: bool = False, + ) -> LogitsProcessorOutput: + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + if not get_embedding: + logits_processor_output: LogitsProcessorOutput = self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + return logits_processor_output + else: + return self.pooler(hidden_states, forward_batch) + + def get_hidden_dim(self, module_name): + # return input_dim, output_dim + if module_name in ["q_proj", "o_proj", "qkv_proj"]: + return self.config.hidden_size, self.config.hidden_size + elif module_name in ["kv_proj"]: + return self.config.hidden_size, self.config.hidden_size // ( + self.config.num_attention_heads // self.config.num_key_value_heads + ) + elif module_name == "gate_up_proj": + return self.config.hidden_size, self.config.intermediate_size + elif module_name == "down_proj": + return self.config.intermediate_size, self.config.hidden_size + else: + raise NotImplementedError() + + def get_module_name(self, name): + params_mapping = { + "q_proj": "qkv_proj", + "k_proj": "qkv_proj", + "v_proj": "qkv_proj", + "gate_proj": "gate_up_proj", + "up_proj": "gate_up_proj", + } + return params_mapping.get(name, name) + + def get_module_name_from_weight_name(self, name): + for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping: + if weight_name in name: + return ( + name.replace(weight_name, param_name)[: -len(".weight")], + num_shard, + ) + return name[: -len(".weight")], 1 + + def get_num_params(self): + params_dict = dict(self.named_parameters()) + return len(params_dict) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name or "projector" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if name.startswith("model.vision_tower") and name not in params_dict: + continue + if "lm_head.weight" in name and self.config.tie_word_embeddings: + # Input and output embeddings are tied, so the output embeddings + # may not be present in the checkpoint. We assume that the input + # embeddings are always present in the checkpoint. + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
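# A small sketch, using a hypothetical checkpoint key, of how the stacked
# parameter mapping above folds per-projection checkpoint weights into the
# fused qkv_proj / gate_up_proj parameters during this loading loop.
stacked_params_mapping = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
    (".gate_up_proj", ".gate_proj", 0),
    (".gate_up_proj", ".up_proj", 1),
]
name = "model.layers.0.self_attn.k_proj.weight"   # hypothetical checkpoint key
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        print(name.replace(weight_name, param_name), shard_id)
        # -> model.layers.0.self_attn.qkv_proj.weight k
        break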
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # This block only runs if the preceding for loop doesn't find + # a match for `name` in `stacked_params_mapping`. + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip loading kv_scale from ckpts towards new design. + if name.endswith(".kv_scale") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + def get_weights_by_name( + self, name: str, truncate_size: int = 100, tp_size: int = 1 + ) -> Optional[torch.Tensor]: + """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face. + + Only used for unit test with an unoptimized performance. + For optimized performance, please use torch.save and torch.load. + """ + try: + if name == "lm_head.weight" and self.config.tie_word_embeddings: + logger.info( + "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight." + ) + return ( + self.model.embed_tokens.weight.cpu() + .to(torch.float32) + .numpy() + .tolist()[:truncate_size] + ) + + mapped_name = name + mapped_shard_id = None + for param_name, weight_name, shard_id in self.stacked_params_mapping: + if weight_name in name: + mapped_name = name.replace(weight_name, param_name) + mapped_shard_id = shard_id + break + params_dict = dict(self.named_parameters()) + param = params_dict[mapped_name] + if mapped_shard_id is not None: + if mapped_shard_id in ["q", "k", "v"]: + num_heads = self.config.num_attention_heads // tp_size + num_kv_heads = self.config.num_key_value_heads // tp_size + head_dim = ( + self.config.hidden_size // self.config.num_attention_heads + ) + if mapped_shard_id == "q": + offset = 0 + size = num_heads * head_dim + elif mapped_shard_id == "k": + offset = num_heads * head_dim + size = num_kv_heads * head_dim + elif mapped_shard_id == "v": + offset = (num_heads + num_kv_heads) * head_dim + size = num_kv_heads * head_dim + weight = param.data.narrow(0, offset, size) + elif mapped_shard_id in [0, 1]: + intermediate_size = self.config.intermediate_size + slice_size = intermediate_size // tp_size + if mapped_shard_id == 0: # gate_proj + offset = 0 + size = slice_size + elif mapped_shard_id == 1: # up_proj + offset = slice_size + size = slice_size + + weight = param.data.narrow(0, offset, size) + else: + weight = param.data + else: + weight = param.data + if tp_size > 1 and ("o_proj" in name or "down_proj" in name): + gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)] + torch.distributed.all_gather(gathered_weights, weight) + weight = torch.cat(gathered_weights, dim=1) + return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size] + + except Exception: + logger.error( + f"Error getting weights by name {name} in GraniteForCausalLM: {get_exception_traceback()}" + ) + return None + + +EntryClass = [GraniteForCausalLM] diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 2b52e2b1bcc..0471e37d982 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -16,25 +16,30 @@ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 """Inference-only Grok1 model.""" -from typing import 
Iterable, Optional, Tuple +from typing import Iterable, List, Optional, Tuple import torch import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope -from sglang.srt.layers.fused_moe_triton import FusedMoE +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -44,6 +49,43 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader +class Grok1MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + reduce_results=True, + use_presharded_weights: bool = False, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + use_presharded_weights=use_presharded_weights, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + reduce_results=reduce_results, + use_presharded_weights=use_presharded_weights, + ) + self.act_fn = GeluAndMul(approximate="tanh") + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + class Grok1MoE(nn.Module): """A tensor-parallel MoE implementation for Grok1 that shards each expert across all ranks. @@ -55,6 +97,7 @@ class Grok1MoE(nn.Module): def __init__( self, + config: PretrainedConfig, num_experts: int, top_k: int, hidden_size: int, @@ -62,6 +105,8 @@ def __init__( params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, + reduce_results=True, + use_presharded_weights: bool = False, ): super().__init__() self.hidden_size = hidden_size @@ -75,25 +120,33 @@ def __init__( quant_config=None, ) + self.router_logit_softcapping = getattr( + config, "router_logit_softcapping", 30.0 + ) self.experts = FusedMoE( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, intermediate_size=intermediate_size, params_dtype=params_dtype, - reduce_results=True, + reduce_results=reduce_results, renormalize=False, quant_config=quant_config, tp_size=tp_size, + activation="gelu", + use_presharded_weights=use_presharded_weights, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # NOTE: hidden_states can have either 1D or 2D shape. 
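# A minimal sketch of the tanh soft-capping applied to the router logits a few
# lines below (assuming the hard-coded 30.0 there plays the role of the
# router_logit_softcapping value read from the config above): cap * tanh(x / cap)
# is close to the identity for |x| << cap and saturates smoothly at +/- cap.
import torch

def softcap(logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    return cap * torch.tanh(logits / cap)

# softcap(torch.tensor([1.0, 100.0])) -> approximately tensor([1.00, 29.92])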
orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) router_logits = 30.0 * F.tanh(router_logits / 30.0) + + # need to assert self.gate.quant_method is unquantized final_hidden_states = self.experts(hidden_states, router_logits) return final_hidden_states.view(orig_shape) @@ -101,16 +154,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class Grok1Attention(nn.Module): def __init__( self, + config: PretrainedConfig, hidden_size: int, num_heads: int, num_kv_heads: int, layer_id: int = 0, max_position: int = 4096 * 32, rope_theta: float = 10000, - logit_cap: float = 30, quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, ) -> None: super().__init__() + self.config = config + self.layer_id = layer_id self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads @@ -126,7 +182,7 @@ def __init__( # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - self.head_dim = 128 + self.head_dim = getattr(config, "head_dim", 128) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 @@ -140,12 +196,12 @@ def __init__( bias=False, quant_config=quant_config, ) - self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, + reduce_results=reduce_results, ) self.rotary_emb = get_rope( self.head_dim, @@ -154,6 +210,9 @@ def __init__( base=int(self.rope_theta), is_neox_style=True, ) + + logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0) + self.attn = RadixAttention( self.num_heads, self.head_dim, @@ -162,7 +221,6 @@ def __init__( layer_id=layer_id, logit_cap=logit_cap, ) - # TODO(lianmin): load logit cap from config def forward( self, @@ -184,12 +242,16 @@ def __init__( config: PretrainedConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + use_presharded_weights: bool = False, ) -> None: super().__init__() + self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size + self.layer_id = layer_id rope_theta = getattr(config, "rope_theta", 10000) self.self_attn = Grok1Attention( + config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, max_position=config.max_position_embeddings, @@ -199,11 +261,18 @@ def __init__( quant_config=quant_config, ) self.block_sparse_moe = Grok1MoE( + config=config, num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, + intermediate_size=getattr( + config, + "moe_intermediate_size", + getattr(config, "intermediate_size", None), + ), quant_config=quant_config, + reduce_results=True, + use_presharded_weights=use_presharded_weights, ) self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -241,6 +310,7 @@ def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + use_presharded_weights: bool = False, ) -> None: super().__init__() self.config = config @@ -253,7 +323,12 @@ def __init__( ) self.layers = nn.ModuleList( [ - Grok1DecoderLayer(config, i, quant_config=quant_config) + Grok1DecoderLayer( + config, 
+ i, + quant_config=quant_config, + use_presharded_weights=use_presharded_weights, + ) for i in range(config.num_hidden_layers) ] ) @@ -284,11 +359,26 @@ def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + cache_config=None, ) -> None: super().__init__() self.config = config self.quant_config = quant_config - self.model = Grok1Model(config, quant_config=quant_config) + + if ( + self.config.num_local_experts > 0 + and get_tensor_model_parallel_world_size() > 1 + ): + self.use_presharded_weights = True + setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) + else: + self.use_presharded_weights = False + + self.model = Grok1Model( + config, + quant_config=quant_config, + use_presharded_weights=self.use_presharded_weights, + ) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) @@ -304,12 +394,19 @@ def forward( input_ids, hidden_states, self.lm_head, forward_batch ) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + ): + num_experts = self.config.num_local_experts + stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] # Params for weights, fp8 weight scales, fp8 activation scales @@ -318,10 +415,23 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", - num_experts=self.config.num_local_experts, + num_experts=num_experts, ) params_dict = dict(self.named_parameters()) + all_names = set(params_dict.keys()) + hit_names = set() + + def load_weight_wrapper(name, loaded_weight, *args, **kwargs): + if name not in params_dict: + return + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, *args, **kwargs) + + hit_names.add(name) + for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue @@ -334,9 +444,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if name.endswith(".bias") and name not in params_dict: continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + load_weight_wrapper(name, loaded_weight, shard_id) break else: for mapping in expert_params_mapping: @@ -345,10 +453,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader( - param, + load_weight_wrapper( + name, loaded_weight, name, shard_id=shard_id, @@ -359,17 +465,56 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip loading kv_scale from ckpts towards new design. 
- if name.endswith(".kv_scale") and name not in params_dict: - continue if name is None: continue - param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) - weight_loader(param, loaded_weight) + load_weight_wrapper(name=name, loaded_weight=loaded_weight) + + +old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights") + + +def _prepare_presharded_weights( + self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool +) -> Tuple[str, List[str], bool]: + import glob + import os + + if get_tensor_model_parallel_world_size() == 1: + return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt) + + if not os.path.isdir(model_name_or_path): + from sglang.srt.model_loader.weight_utils import download_weights_from_hf + + allow_patterns = ["*.safetensors", "*.bin"] + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + + tp_rank = get_tensor_model_parallel_rank() + + # The old format + allow_patterns = [f"*-{tp_rank:03d}.bin"] + + # The new format + allow_patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"] + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + + if hf_weights_files[0].endswith("safetensors"): + use_safetensors = True + else: + use_safetensors = False + + return hf_folder, hf_weights_files, use_safetensors class Grok1ModelForCausalLM(Grok1ForCausalLM): diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index 0a737c1388b..ce8f9a3cf65 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -19,9 +19,8 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -32,6 +31,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index e3e44ea6ffc..4ea77eede9b 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -22,9 +22,11 @@ import torch from torch import nn from transformers import LlamaConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -36,12 +38,16 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import 
RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + kv_cache_scales_loader, +) from sglang.srt.utils import make_layers from sglang.utils import get_exception_traceback @@ -100,6 +106,7 @@ def __init__( max_position_embeddings: int = 8192, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + bias: bool = False, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -132,14 +139,14 @@ def __init__( self.head_dim, self.total_num_heads, self.total_num_kv_heads, - bias=False, + bias=bias, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, - bias=False, + bias=bias, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) @@ -194,6 +201,11 @@ def __init__( ) rope_is_neox_style = getattr(config, "rope_is_neox_style", True) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + # Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False + ) self.self_attn = LlamaAttention( config=config, hidden_size=self.hidden_size, @@ -206,6 +218,7 @@ def __init__( max_position_embeddings=max_position_embeddings, quant_config=quant_config, prefix=f"{prefix}.self_attn", + bias=attention_bias, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -292,8 +305,54 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.layers[layer_idx], nn.Identity): + layer_self_attn = self.layers[layer_idx].self_attn + + if hasattr(layer_self_attn.attn, "k_scale"): + layer_self_attn.attn.k_scale = scaling_factor + layer_self_attn.attn.v_scale = scaling_factor + else: + raise RuntimeError( + "Self attention has no KV cache scaling " "factor attribute!" 
+ ) + class LlamaForCausalLM(nn.Module): + + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + # in TP, these weights are partitioned along the column dimension (dim=-1) + column_parallel_weights_modules = [".down_proj.", ".o_proj."] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + def __init__( self, config: LlamaConfig, @@ -303,8 +362,8 @@ def __init__( self.config = config self.quant_config = quant_config self.model = LlamaModel(config, quant_config=quant_config) - # Llama 3.2 1B Insturct set tie_word_embeddings to True - # Llama 3.1 8B Insturct set tie_word_embeddings to False + # Llama 3.2 1B Instruct set tie_word_embeddings to True + # Llama 3.1 8B Instruct set tie_word_embeddings to False if self.config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: @@ -494,9 +553,27 @@ def get_weights_by_name( ) return None + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + self.model.load_kv_cache_scales(quantization_param_path) + class Phi3ForCausalLM(LlamaForCausalLM): pass -EntryClass = [LlamaForCausalLM, Phi3ForCausalLM] +class InternLM3ForCausalLM(LlamaForCausalLM): + pass + + +EntryClass = [LlamaForCausalLM, Phi3ForCausalLM, InternLM3ForCausalLM] diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index 038732476ed..75e8af9af32 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -18,7 +18,7 @@ from torch import nn from transformers import LlamaConfig -from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader @@ -33,14 +33,13 @@ def __init__( ) -> None: super().__init__() self.config = config - self.torchao_config = None self.quant_config = quant_config self.model = LlamaModel(config, quant_config=quant_config) self.classification_head = nn.Linear( config.hidden_size, config.classification_out_size, bias=False ) - self.eos_token_id = config.eos_token_id + self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False) @torch.no_grad() def forward( @@ -49,28 +48,17 @@ def forward( positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) - is_eos_token = input_ids == self.eos_token_id - hidden_states = hidden_states[is_eos_token] - scores = self.classification_head(hidden_states) - - if scores.shape[0] != forward_batch.batch_size: - print("Warning: the EOS tokens are missing in some sentences.") - scores = torch.ones( - (forward_batch.batch_size, 
self.config.classification_out_size) - ).to(input_ids.device) + get_embedding: bool = True, + ) -> EmbeddingPoolerOutput: + assert ( + get_embedding + ), "LlamaForClassification is only used for embedding. Please add --is-embedding when you launch the server." - logits_output = LogitsProcessorOutput( - next_token_logits=scores, - next_token_logprobs=scores, - normalized_prompt_logprobs=scores, - input_token_logprobs=torch.ones_like(input_ids), - input_top_logprobs=None, - output_top_logprobs=None, - ) + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings + scores = self.classification_head(last_token_hidden) - return logits_output + return EmbeddingPoolerOutput(scores) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py new file mode 100644 index 00000000000..09bfbb170c0 --- /dev/null +++ b/python/sglang/srt/models/llama_eagle.py @@ -0,0 +1,132 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# Adapted from +# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py +"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights.""" + +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import LlamaConfig + +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM + + +class LlamaDecoderLayer(LlamaDecoderLayer): + def __init__( + self, + config: LlamaConfig, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, layer_id, quant_config, prefix) + + # Skip the input_layernorm + # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 + if layer_id == 0: + del self.input_layernorm + setattr(self, "input_layernorm", lambda x: x) + + +class LlamaModel(nn.Module): + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer( + config, i, quant_config=quant_config, prefix=f"model.layers.{i}" + ) + for i in range(config.num_hidden_layers) + ] + ) + self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + 
input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + hidden_states = self.fc( + torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1) + ) + + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, + ) + return hidden_states + residual + + +class LlamaForCausalLMEagle(LlamaForCausalLM): + def __init__( + self, + config: LlamaConfig, + quant_config: Optional[QuantizationConfig] = None, + cache_config=None, + ) -> None: + nn.Module.__init__(self) + self.config = config + self.quant_config = quant_config + self.model = LlamaModel(config, quant_config=quant_config) + # Llama 3.2 1B Instruct set tie_word_embeddings to True + # Llama 3.1 8B Instruct set tie_word_embeddings to False + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + self.logits_processor = LogitsProcessor(config) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." + name + super().load_weights([(name, loaded_weight)]) + + +EntryClass = [LlamaForCausalLMEagle] diff --git a/python/sglang/srt/models/llama_reward.py b/python/sglang/srt/models/llama_reward.py index dcde8b468ea..6550ee411a1 100644 --- a/python/sglang/srt/models/llama_reward.py +++ b/python/sglang/srt/models/llama_reward.py @@ -21,7 +21,6 @@ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel @@ -33,7 +32,6 @@ def __init__( ) -> None: super().__init__() self.config = config - self.torchao_config = None self.quant_config = quant_config self.num_labels = config.num_labels self.model = LlamaModel(config, quant_config=quant_config) diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 4c62dbb25f1..c8ce9302b4f 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -57,6 +57,7 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): else: image_aspect_ratio = "anyres" offset_list = [] + image_inputs.image_pad_len = [] for image_idx, image_s in enumerate(image_sizes): if len(image_sizes) > 16: # 2x2 pooling with stride 2 @@ -103,6 +104,7 @@ def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): + input_ids[offset + 1 :] ) offset_list.append(offset) + image_inputs.image_pad_len.append(new_image_feature_len) image_inputs.image_offsets = offset_list return input_ids @@ -134,6 +136,14 @@ def forward( image_inputs = forward_batch.image_inputs if forward_batch.forward_mode.is_extend(): + # Clamp input ids. This is because the input_ids for the image tokens are + # filled with the hash values of the image for the prefix matching in the radix attention. + # There values are useless because their embeddings will be replaced by vision embeddings anyway. 
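# A tiny sketch, with hypothetical values, of the clamping step right below:
# image placeholder positions carry large hash-derived ids that exist only for
# radix-cache prefix matching, so they are clamped into the valid vocabulary
# range before the text embedding lookup (their embeddings are replaced by
# vision features afterwards anyway).
import torch

vocab_size = 32000                              # hypothetical
ids = torch.tensor([1, 15, 987654321, 7])       # third entry is an image-hash id
ids.clamp_(min=0, max=vocab_size - 1)           # -> tensor([1, 15, 31999, 7])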
+ input_ids.clamp_(min=0, max=self.config.vocab_size - 1) + + # Embed text inputs + input_embeds = self.language_model.model.embed_tokens(input_ids) + # Got List[List[str]] extend it to List[str] # The length of the List should be equal to batch size modalities_list = [] @@ -142,18 +152,12 @@ def forward( if im and im.modalities is not None: modalities_list.extend(im.modalities) if im and im.image_offsets: - max_image_offset.append(max(im.image_offsets)) + max_image_offset.append( + np.max(np.array(im.image_offsets) + np.array(im.image_pad_len)) + ) else: max_image_offset.append(-1) - # Clamp input ids. This is because the input_ids for the image tokens are - # filled with the hash values of the image for the prefix matching in the radix attention. - # There values are useless because their embeddings will be replaced by vision embeddings anyway. - input_ids.clamp_(min=0, max=self.config.vocab_size - 1) - - # Embed text inputs - input_embeds = self.language_model.model.embed_tokens(input_ids) - start_positions = positions[forward_batch.extend_start_loc].cpu().numpy() need_vision = start_positions <= np.array(max_image_offset) @@ -350,6 +354,7 @@ def forward( # Fill in the placeholder for the image extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy() + extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy() prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu pt = 0 for i in range(bs): @@ -357,18 +362,36 @@ def forward( continue start_idx = extend_start_loc_cpu[i] + seq_len = extend_seq_lens[i] prefix_len = prefix_lens_cpu[i] # Multiple images - for j, image_offset in enumerate(image_inputs[i].image_offsets): - if image_offset < prefix_len: + for image_idx, image_offset in enumerate( + image_inputs[i].image_offsets + ): + if ( + image_offset + image_inputs[i].image_pad_len[image_idx] + <= prefix_len + ): continue + if image_offset >= prefix_len + seq_len: + break - tmp_image_feature = image_features[pt][j] + tmp_image_feature = image_features[pt][image_idx] pad_len = tmp_image_feature.shape[0] - left_idx = start_idx + (image_offset - prefix_len) - right_idx = start_idx + (image_offset - prefix_len) + pad_len + input_offset = image_offset - prefix_len + left_idx = start_idx + input_offset + right_idx = left_idx + pad_len + assert right_idx > start_idx + if input_offset < 0: + left_idx = start_idx + tmp_image_feature = tmp_image_feature[-input_offset:] + if right_idx > start_idx + seq_len: + tmp_image_feature = tmp_image_feature[ + : start_idx + seq_len - right_idx + ] + right_idx = start_idx + seq_len try: input_embeds[left_idx:right_idx] = tmp_image_feature except RuntimeError as e: diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 3482a828132..f5e69411acc 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -18,9 +18,8 @@ import torch from torch import nn -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -31,6 +30,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from 
sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index b0c93274e2b..31ea7cd9f25 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -19,20 +19,20 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.linear import ( + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, ReplicatedLinear, RowParallelLinear, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - -from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -40,10 +40,10 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import is_flashinfer_available +from sglang.srt.utils import is_cuda_available -if is_flashinfer_available(): - from flashinfer import bmm_fp8 +if is_cuda_available(): + from sgl_kernel import bmm_fp8 class MiniCPM3MLP(nn.Module): diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py new file mode 100644 index 00000000000..7b02b4cedbb --- /dev/null +++ b/python/sglang/srt/models/minicpmv.py @@ -0,0 +1,1291 @@ +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The SGLang team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" +from functools import partial +from typing import ( + Any, + Callable, + Iterable, + List, + Literal, + Optional, + Tuple, + TypedDict, + Union, +) + +import numpy as np +import torch +import torch.types +from PIL import Image +from torch import nn +from torch.nn.init import trunc_normal_ +from transformers import PretrainedConfig + +from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import get_act_fn +from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.schedule_batch import ImageInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.utils import set_default_torch_dtype +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM + +RawImageType = Union[Image.Image, torch.Tensor] + + +# sin/cos positional embedding helpers are adapted from: +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: np.ndarray, version: Tuple[int, int] = (2, 0) +) -> torch.Tensor: + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) / (H, W) + out: (M, D) / (H, W, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + if version == (2, 0): + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + else: + out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product + emb_sin = np.sin(out) # (H, W, D/2) + emb_cos = np.cos(out) # (H, W, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: np.ndarray, version: Tuple[int, int] = (2, 0) +) -> torch.Tensor: + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0], version + ) # (H*W, D/2) or (H, W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1], version + ) # (H*W, D/2) or (H, W, D/2) + + if version == (2, 0): + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + else: + emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: Tuple[int, int] = (2, 0), +) -> torch.Tensor: + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_h_size, grid_w_size = grid_size, grid_size + else: + grid_h_size, grid_w_size = grid_size[0], grid_size[1] + + grid_h = np.arange(grid_h_size, dtype=np.float32) + grid_w = np.arange(grid_w_size, dtype=np.float32) + grid = 
np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + assert isinstance(grid, np.ndarray) and grid.shape == (2, grid_h_size, grid_w_size) + + if version == (2, 0): + grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + else: + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + return pos_embed + + +class Idefics2VisionMLP(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Idefics2EncoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.embed_dim = config.hidden_size + + self.num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + num_heads_per_partition = divide(self.num_heads, tp_size) + self.self_attn = VisionAttention( + embed_dim=config.hidden_size, + num_heads=num_heads_per_partition, + projection_size=config.intermediate_size, + use_qkv_parallel=True, + quant_config=quant_config, + dropout=config.attention_dropout, + use_context_forward=False, + use_full_precision_softmax=True, + flatten_batch=False, + prefix=f"{prefix}.self_attn", + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Idefics2VisionMLP(config, quant_config=quant_config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + + """ + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn(hidden_states, cu_seqlens=cu_seqlens) + + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Idefics2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self attention + layers. Each layer is a + [`Idefics2EncoderLayer`]. 
+ + Args: + config: Idefics2Config + """ + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.layers = nn.ModuleList( + [ + Idefics2EncoderLayer( + config, + quant_config=quant_config, + ) + for _ in range(config.num_hidden_layers) + ] + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + cu_seqlens: torch.Tensor, + ) -> torch.Tensor: + r""" + Args: + inputs_embeds (torch.Tensor): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. + This is useful if you want more control over how to convert + `input_ids` indices into associated vectorsthan the model's + internal embedding lookup matrix. + """ + hidden_states = inputs_embeds + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + cu_seqlens=cu_seqlens, + ) + hidden_states = layer_outputs + return hidden_states + + +class Idefics2VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings + ` to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision + Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the + need to resize them to the same fixed size. In particular, we start from the + original pre-trained SigLIP model(which uses images of fixed-size square + images) and adapt it by training on images of variable resolutions. + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def get_position_ids( + self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None, + ): + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + + if tgt_sizes is not None: + nb_patches_h = tgt_sizes[batch_idx][0] + nb_patches_w = tgt_sizes[batch_idx][1] + else: + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + bucket_coords_h = torch.bucketize( + fractional_coords_h, boundaries, right=True + ) + bucket_coords_w = torch.bucketize( + fractional_coords_w, boundaries, right=True + ) + pos_ids = ( + bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w + ).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + position_ids = 
position_ids.to(self.position_embedding.weight.device) + return position_ids + + def forward( + self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None, + ) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + pixel_values = pixel_values.to( + device=self.patch_embedding.weight.device, dtype=target_dtype + ) + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + position_ids = self.get_position_ids( + pixel_values, patch_attention_mask, tgt_sizes + ) + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Idefics2VisionTransformer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + embed_dim = config.hidden_size + self.config = config + self.embeddings = Idefics2VisionEmbeddings(config) + self.encoder = Idefics2Encoder(config=config, quant_config=quant_config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + def get_input_embeddings(self): + return self.embeddings + + def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor: + patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] # shape: (batch_size,) + cu_seqlens = torch.cat( + [ + torch.tensor([0], device=patch_len.device, dtype=torch.int32), + torch.cumsum(patch_len, dim=0, dtype=torch.int32), + ], + dim=0, + ).to(tgt_sizes.device) + return cu_seqlens + + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + tgt_sizes: Optional[torch.IntTensor] = None, + ) -> torch.Tensor: + hidden_states = self.embeddings( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + tgt_sizes=tgt_sizes, + ) + cu_seqlens = self.compute_cu_seqlens(tgt_sizes) + encoder_outputs = self.encoder( + hidden_states, + cu_seqlens=cu_seqlens, + ) + last_hidden_state = self.post_layernorm(encoder_outputs) + return last_hidden_state + + +class MiniCPMVImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: List[torch.Tensor] + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + + Note that the image size may vary, so we pass it as a list + instead of a batched tensor. + """ + + image_bounds: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(start, stop)` format. + """ + + tgt_sizes: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(height, width)` format. + """ + + +class MiniCPMVImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """ + Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of language model backbone. + instead of a batched tensor. + """ + + image_bounds: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(start, stop)` format. + """ + + +MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageEmbeddingInputs] + +DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) + + +class BaseResampler(nn.Module): + """ + A 2D perceiver-resampler network with one cross attention layers by + (grid_size**2) learnable queries and 2d sincos pos_emb. 
+ Outputs: + A tensor with the shape of (grid_size**2, embed_dim) + """ + + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + do_post_projection: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.num_queries = num_queries + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) + trunc_normal_(self.query, std=0.02) + if kv_dim is not None and kv_dim != embed_dim: + self.kv_proj = ReplicatedLinear( + kv_dim, + embed_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_proj", + ) + else: + # Maintain the same return value with ReplicatedLinear.forward + self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa + nn.Identity()(*args, **kwargs), + None, + ) + self.attn = nn.MultiheadAttention(embed_dim, num_heads) + self.ln_q = norm_layer(embed_dim) + self.ln_kv = norm_layer(embed_dim) + self.do_post_projection = do_post_projection + self.ln_post = norm_layer(embed_dim) if do_post_projection else None + self.proj = ( + nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim)) + if do_post_projection + else None + ) + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class Resampler2_5(BaseResampler): + + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: Tuple[int, int] = (70, 70), + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + num_queries, + embed_dim, + num_heads, + kv_dim, + norm_layer, + quant_config=quant_config, + prefix=prefix, + ) + + self.max_size = max_size + self._set_2d_pos_cache(self.max_size) + + self.apply(self._init_weights) + + def _set_2d_pos_cache( + self, max_size: Tuple[int, int], device: torch.types.Device = "cpu" + ) -> None: + pos_embed_arr = get_2d_sincos_pos_embed( + self.embed_dim, max_size, version=(2, 5) + ) + pos_embed = torch.from_numpy(pos_embed_arr).float().to(device) + self.register_buffer("pos_embed", pos_embed, persistent=False) + + def _adjust_pos_cache( + self, tgt_sizes: torch.Tensor, device: torch.types.Device + ) -> None: + max_h = tgt_sizes[:, 0].max().item() + max_w = tgt_sizes[:, 1].max().item() + assert isinstance(max_h, int) and isinstance(max_w, int) + + if max_h > self.max_size[0] or max_w > self.max_size[1]: + self.max_size = ( + max(max_h, self.max_size[0]), + max(max_w, self.max_size[1]), + ) + self._set_2d_pos_cache(self.max_size, device) + + def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor: + assert x.shape[0] == tgt_sizes.shape[0] + bs = x.shape[0] + + device = x.device + dtype = x.dtype + + patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] + + self._adjust_pos_cache(tgt_sizes, device=device) + + max_patch_len = patch_len.max().item() + assert isinstance(max_patch_len, int) + + key_padding_mask = torch.zeros( + (bs, max_patch_len), dtype=torch.bool, device=device + ) + + pos_embed = [] + for i in 
range(bs): + tgt_h, tgt_w = tgt_sizes[i].tolist() + pos_embed.append( + self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype) + ) # patches * D + key_padding_mask[i, patch_len[i] :] = True + pos_embed = torch.nn.utils.rnn.pad_sequence( + pos_embed, batch_first=True, padding_value=0.0 + ).permute( + 1, 0, 2 + ) # BLD => L * B * D + x, _ = self.kv_proj(x) # B * L * D + x = self.ln_kv(x).permute(1, 0, 2) # L * B * D + + q = self.ln_q(self.query) # Q * D + + out = self.attn( + self._repeat(q, bs), # Q * B * D + x + pos_embed, # L * B * D + L * B * D + x, + key_padding_mask=key_padding_mask, + )[0] + # out: Q * B * D + x = out.permute(1, 0, 2) # B * Q * D + + x = self.ln_post(x) + x = x @ self.proj + return x + + +def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: + version_float = getattr(config, "version", None) + + # The old configs do not include version number + # TODO: Remove this after the HF repos are updated + if version_float is None: + if config.hidden_size == 2304 and config.query_num == 64: + return 2, 0 + return 2, 5 + + version_str = str(version_float) + return tuple(int(x) for x in version_str.split(".")) + + +class MiniCPMVBaseModel(nn.Module): + """ + The abstract class of MiniCPMV can only be inherited, but cannot be + instantiated. + """ + + def __init__( + self, + *, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + # All MiniCPM-V models disable `tie_word_embeddings` but + # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot + # check `tie_word_embeddings` until SGLang integrate MiniCPM-V model + # and config class + self.config = config + + self.version = get_version_by_config(self.config) + self.llm = self.init_llm(config=config, quant_config=quant_config) + self.vpm = self.init_vision_module(config, quant_config) + self.vision_dim = ( + self.vpm.embed_dim + if self.version == (2, 0) + else self.vpm.embeddings.embed_dim + ) + self.embed_dim = self.config.hidden_size + + self.resampler = self.init_resampler( + self.embed_dim, self.vision_dim, quant_config=quant_config + ) + + self.logits_processor = LogitsProcessor(config) + + def _get_image_bounds( + self, + input_ids: torch.Tensor, + pad_values: List[int], + im_start_id: torch.Tensor, + im_end_id: torch.Tensor, + slice_start_id: Optional[torch.Tensor] = None, + slice_end_id: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Returns a tensor indicating the bounds (start and end token ids) of the images + """ + # All the images in the batch should share the same special image + # bound token ids. 
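To make the bounds format concrete, a small worked example follows; the marker ids 101 and 102 are invented for illustration.

# Editor's illustration, not part of the patch: bounds for a single image region.
#   input_ids          = torch.tensor([5, 101, 7, 7, 7, 102, 9])   # 101=<im_start>, 102=<im_end>
#   image_start_tokens = torch.where(input_ids == 101)[0] + 1      # tensor([2])
#   image_end_tokens   = torch.where(input_ids == 102)[0]          # tensor([5])
#   returned bounds    = [[2, 5]]   # positions 2..4 are later overwritten by vision embeddings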
+ start_cond = input_ids == im_start_id[0] + end_cond = input_ids == im_end_id[0] + if slice_start_id is not None: + start_cond |= input_ids == slice_start_id[0] + end_cond |= input_ids == slice_end_id[0] + + (image_start_tokens,) = torch.where(start_cond) + image_start_tokens += 1 + (image_end_tokens,) = torch.where(end_cond) + + # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images + if len(image_start_tokens) != len(image_end_tokens): + if ( + len(image_start_tokens) + 1 == len(image_end_tokens) + and input_ids[0] in pad_values + and image_end_tokens[0] < image_start_tokens[0] + ): + image_start_tokens = torch.cat( + [ + torch.tensor([0], device=image_start_tokens.device), + image_start_tokens, + ] + ) + valid_image_nums = min(len(image_start_tokens), len(image_end_tokens)) + + if valid_image_nums == 0: + return torch.zeros((0, 2), device=input_ids.device) + + # Filter out pairs where start_token >= end_token + valid_pairs = [] + for i in range(valid_image_nums): + start_token = image_start_tokens[i] + end_token = image_end_tokens[i] + if start_token < end_token: + valid_pairs.append((start_token, end_token)) + + if not valid_pairs: + return torch.zeros((0, 2), device=input_ids.device) + + # Convert valid pairs to tensor + valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device) + return valid_pairs_tensor + + def get_embedding( + self, + input_ids: torch.Tensor, + image_inputs: Optional[MiniCPMVImageInputs], + ) -> Tuple[torch.Tensor, torch.Tensor]: + vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids) + + if image_inputs is None: # No image + vision_hidden_states = torch.tensor([], device=input_ids.device) + else: + if image_inputs["type"] == "image_embeds": + vision_hidden_states = ( + image_inputs["data"] + .type(vlm_embedding.dtype) + .to(vlm_embedding.device) + ) + else: + vision_hidden_states = self.get_vision_hidden_states(image_inputs) + # See NOTE in _parse_and_validate_inputs + image_bounds = image_inputs["image_bounds"] + if len(image_bounds) > 0: + image_indices = torch.stack( + [ + torch.arange(start, end, dtype=torch.long) + for start, end in image_bounds.tolist() + ] + ).to(vlm_embedding.device) + + vlm_embedding.scatter_( + 0, + image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]), + vision_hidden_states.view(-1, vision_hidden_states.shape[-1]), + ) + + return vlm_embedding, vision_hidden_states + + def _parse_and_validate_inputs( + self, + input_ids: torch.Tensor, + **kwargs: object, + ) -> Optional[MiniCPMVImageInputs]: + pixel_values = kwargs.pop("pixel_values", []) + tgt_sizes = kwargs.pop("tgt_sizes", []) + im_start_id = kwargs.pop("im_start_id", None) + im_end_id = kwargs.pop("im_end_id", None) + slice_start_id = kwargs.pop("slice_start_id", None) + slice_end_id = kwargs.pop("slice_end_id", None) + image_embeds = kwargs.pop("image_embeds", None) + pad_values = kwargs.pop("pad_values", None) + + if image_embeds is not None: + image_bounds = self._get_image_bounds( + input_ids=input_ids, + pad_values=pad_values, + im_start_id=im_start_id, + im_end_id=im_end_id, + slice_start_id=slice_start_id, + slice_end_id=slice_end_id, + ) + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of image embeds. 
" + f"Got type: {type(image_embeds)}" + ) + + if isinstance(image_embeds, list): + image_embeds = torch.concat(image_embeds) + + return MiniCPMVImageEmbeddingInputs( + image_bounds=image_bounds, + data=image_embeds, + type="image_embeds", + ) + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}" + ) + + if not isinstance(tgt_sizes, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}" + ) + + if len(pixel_values) != len(tgt_sizes): + raise ValueError( + "Inconsistent batch lengths, found: " + f"{len(pixel_values)} vs. {len(tgt_sizes)}" + ) + + pixel_values_flat: List[torch.Tensor] = [] + tgt_sizes_flat: List[torch.Tensor] = [] + for pixel_b, tgt_b in zip(pixel_values, tgt_sizes): + if len(pixel_b) != len(tgt_b): + raise ValueError( + "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}" + ) + + for pixel_n, tgt_n in zip(pixel_b, tgt_b): + pixel_values_flat += pixel_n + tgt_sizes_flat += tgt_n + + # NOTE: Input IDs does not contain image tokens during memory profiling, + # so we allow it to be empty + if len(pixel_values_flat) != len(tgt_sizes_flat): + raise ValueError( + "Inconsistent flattened lengths, found: " + f"{len(pixel_values_flat)} vs. " + f"{len(tgt_sizes_flat)}" + ) + + if len(pixel_values_flat) == 0: + return None + + image_bounds = self._get_image_bounds( + input_ids=input_ids, + pad_values=pad_values, + im_start_id=im_start_id, + im_end_id=im_end_id, + slice_start_id=slice_start_id, + slice_end_id=slice_end_id, + ) + return MiniCPMVImagePixelInputs( + image_bounds=image_bounds.to(device=input_ids.device), + data=pixel_values_flat, + tgt_sizes=torch.stack(tgt_sizes_flat), + type="pixel_values", + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + **kwargs: Any, + ) -> torch.Tensor: + if forward_batch.image_inputs is not None and forward_batch.image_inputs != [ + None + ]: + kwargs.update( + { + "pixel_values": ( + None + if forward_batch.image_inputs is None + else [ + i.pixel_values + for i in forward_batch.image_inputs + if i is not None + ] + ), + "tgt_sizes": ( + None + if forward_batch.image_inputs is None + else [ + i.tgt_sizes + for i in forward_batch.image_inputs + if i is not None + ] + ), + "im_start_id": forward_batch.image_inputs[0].im_start_id, + "im_end_id": forward_batch.image_inputs[0].im_end_id, + "slice_start_id": forward_batch.image_inputs[0].slice_start_id, + "slice_end_id": forward_batch.image_inputs[0].slice_end_id, + "pad_values": forward_batch.image_inputs[0].pad_values, + } + ) + + image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) + + # Clamp input ids. This is because the input_ids for the image tokens are + # filled with the hash values of the image for the prefix matching in the radix attention. + # There values are useless because their embeddings will be replaced by vision embeddings anyway. 
+ input_ids.clamp_(min=0, max=self.config.vocab_size - 1) + + vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) + + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + # for `torch.compile` integration + input_ids = None + + hidden_states = self.llm.model( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + input_embeds=vlm_embeddings, + ) + + return self.logits_processor( + input_ids, hidden_states, self.llm.lm_head, forward_batch + ) + + def init_llm( + self, + config: Qwen2Config, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + raise NotImplementedError + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + ) -> nn.Module: + raise NotImplementedError + + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + raise NotImplementedError + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: + raise NotImplementedError + + +class MiniCPMV2_6(MiniCPMVBaseModel): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + # vision encoder + "fc1", + "fc2", + "out_proj", + # language model + "qkv_proj", # same name with vision encoder + "o_proj", + "gate_up_proj", + "down_proj", + # resampler + "kv_proj", + ] + + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config=config, quant_config=quant_config) + assert self.version == (2, 6) + + def init_llm( + self, + config: Qwen2Config, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + return Qwen2ForCausalLM(config=config, quant_config=quant_config) + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + ) -> nn.Module: + model = Idefics2VisionTransformer( + config=config.vision_config, quant_config=quant_config + ) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + + setattr(model, "embed_dim", model.embeddings.embed_dim) + setattr(model, "patch_size", model.embeddings.patch_size) + return model + + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + with set_default_torch_dtype(torch.float16): + # The resampler in 2.6 remains consistent with the one in 2.5. 
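For a sense of scale, a hedged numerical example of the constructor call that follows; all sizes are hypothetical.

# Editor's illustration, not part of the patch. With hypothetical sizes query_num=64,
# embed_dim=3584 and vision_dim=1152, the call below becomes:
#   Resampler2_5(num_queries=64, embed_dim=3584, num_heads=3584 // 128,   # 28 heads
#                kv_dim=1152, quant_config=quant_config)
# so each image slice is compressed to a fixed 64 x 3584 embedding block before the LLM.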
+ resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + ) + + return resampler.to(device="cuda", dtype=torch.get_default_dtype()) + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + vision_embedding = self.vpm( + pixel_values, + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes, + ) + return vision_embedding + + def get_vision_hidden_states( + self, + data: MiniCPMVImageInputs, + ) -> torch.Tensor: + pixel_values = data["data"] + tgt_sizes = data["tgt_sizes"] + + device = self.vpm.embeddings.position_embedding.weight.device + dtype = self.vpm.embeddings.position_embedding.weight.dtype + all_pixel_values_lst = [ + i.flatten(end_dim=1).permute(1, 0) for i in pixel_values + ] + + max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() + assert isinstance(max_patches, int) + + all_pixel_values = torch.nn.utils.rnn.pad_sequence( + all_pixel_values_lst, batch_first=True, padding_value=0.0 + ) + B, L, _ = all_pixel_values.shape + all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L) + patch_attn_mask = torch.zeros( + (B, 1, max_patches), dtype=torch.bool, device=device + ) + + tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device) + mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1] + patch_attn_mask[:, 0, :] = torch.arange( + patch_attn_mask.size(2), device=patch_attn_mask.device + ).unsqueeze(0) < mask_shapes.unsqueeze(1) + + vision_embedding = self.vpm( + all_pixel_values.type(dtype), + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes, + ) + return self.resampler(vision_embedding, tgt_sizes) + + def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): + if not isinstance(image_inputs.im_start_id, list) or not isinstance( + image_inputs.im_end_id, list + ): + return input_ids + + new_input_ids = [] + last_idx = 0 + image_idx = -1 + image_inputs.image_offsets = [] + + # Get all special token IDs + im_start_id = ( + image_inputs.im_start_id[0].item() + if isinstance(image_inputs.im_start_id[0], torch.Tensor) + else image_inputs.im_start_id[0] + ) + im_end_id = ( + image_inputs.im_end_id[0].item() + if isinstance(image_inputs.im_end_id[0], torch.Tensor) + else image_inputs.im_end_id[0] + ) + slice_start_id = ( + image_inputs.slice_start_id[0].item() + if isinstance(image_inputs.slice_start_id[0], torch.Tensor) + else image_inputs.slice_start_id[0] + ) + slice_end_id = ( + image_inputs.slice_end_id[0].item() + if isinstance(image_inputs.slice_end_id[0], torch.Tensor) + else image_inputs.slice_end_id[0] + ) + + # Find all start and end positions for both types + start_indices = [ + i + for i, x in enumerate(input_ids) + if x == im_start_id or x == slice_start_id + ] + end_indices = [ + i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id + ] + + if len(start_indices) != len(end_indices): + return input_ids + # Process each region (both image and slice) + for start_idx, end_idx in zip(start_indices, end_indices): + # Add non-image tokens before this region + new_input_ids.extend( + input_ids[last_idx : start_idx + 1] + ) # include start token + + is_image_start = input_ids[start_idx] == im_start_id + + if is_image_start: + image_inputs.image_offsets += [start_idx] + image_idx += 1 + + num_tokens = end_idx - start_idx - 1 # exclude start and 
end tokens + + # Generate pad_ids + pad_values = [image_inputs.pad_values[image_idx]] + + pad_ids = pad_values * ((num_tokens + len(pad_values)) // len(pad_values)) + pad_ids = pad_ids[:num_tokens] + + # Add pad_ids + new_input_ids.extend(pad_ids) + + # Update last_idx to after end token + last_idx = end_idx + + # Add remaining tokens after last region + new_input_ids.extend(input_ids[last_idx:]) + assert len(input_ids) == len(new_input_ids) + return new_input_ids + + +_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6} + + +class MiniCPMV: + """ + Different versions of MiniCPMV use different visual encoders and LLMs, + which is not conducive to the current integration logic of LoRA and + bitsandbytes in SGLang. Therefore, it is necessary to separate them. + """ + + # Ensure that the LoRA support check passes when the class is not + # initialized, but set all these attributes to empty. + packed_modules_mapping = {} + supported_lora_modules = [] + embedding_modules = {} + embedding_padding_modules = [] + + minicpmv: nn.Module + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + if not hasattr(config, "version"): + version = (2, 6) + else: + version = str(config.version).split(".") + version = tuple([int(x) for x in version]) + # Dispatch class based on version + instance_class = _SUPPORT_VERSION.get(version) + if instance_class is None: + raise ValueError("Currently, MiniCPMV only supports versions 2.6") + + try: + minicpmv = instance_class(config=config, quant_config=quant_config) + self.minicpmv = minicpmv + except Exception as e: + print(f"Failed to instantiate MiniCPMV: {e}") + raise e + self.config = config + + def __getattr__(self, name): + if name == "minicpmv": + return None + return getattr(self.minicpmv, name) + + def __call__(self, *args, **kwargs): + return self.minicpmv(*args, **kwargs) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.minicpmv.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq~" in name or "projector" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if name.startswith("model.vision_tower") and name not in params_dict: + continue + + # adapt to VisionAttention + name = name.replace(r"self_attn.out_proj", r"self_attn.proj") + + if "sampler" in name: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + # replace the name and load with customized loader + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
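A brief, hedged trace of the stacked-parameter remapping above; the checkpoint key shown is only illustrative.

# Editor's illustration, not part of the patch: one pass through stacked_params_mapping.
#   name = "llm.model.layers.0.self_attn.q_proj.weight"      # illustrative checkpoint key
#   it matches ("qkv_proj", "q_proj", "q"), so
#   name -> "llm.model.layers.0.self_attn.qkv_proj.weight"   # fused parameter in params_dict
#   weight_loader(param, loaded_weight, "q")                 # writes the q shard of qkv_proj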
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = MiniCPMV diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index f3fad226091..4ea734836af 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -21,14 +21,11 @@ import torch from torch import nn from transformers import MixtralConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - -from sglang.srt.layers.ep_moe.layer import EPMoE -from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( QKVParallelLinear, @@ -36,8 +33,11 @@ RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.ep_moe.layer import EPMoE +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index e5f49f5662f..244dc7df2d0 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -23,13 +23,12 @@ import torch.nn.functional as F from torch import nn from transformers import MixtralConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( QKVParallelLinear, @@ -39,6 +38,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 019d21c2086..05069edb69b 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -8,15 +8,16 @@ import torch.nn.functional as F import torch.utils.checkpoint import transformers.models.mllama.configuration_mllama as config_mllama -import vllm.distributed.parallel_state as ps from torch import nn from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast from transformers.models.mllama.modeling_mllama import ( _prepare_aspect_ratio_attention_mask, ) -from vllm.distributed import get_tensor_model_parallel_world_size +import sglang.srt.distributed.parallel_state as ps +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import get_act_fn +from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( 
ColumnParallelLinear, @@ -145,61 +146,6 @@ def forward( return hidden_state -class MllamaVisionSdpaAttention(nn.Module): - def __init__(self, config: config_mllama.MllamaVisionConfig): - super().__init__() - - model_parallel_size = get_tensor_model_parallel_world_size() - self.embed_dim = config.hidden_size - self.num_heads = config.attention_heads - self.head_dim = config.hidden_size // config.attention_heads - self.num_local_heads = self.num_heads // model_parallel_size - self.q_size = self.num_local_heads * self.head_dim - self.kv_size = self.num_local_heads * self.head_dim - - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - bias=False, - ) - self.o_proj = RowParallelLinear( - self.num_heads * self.head_dim, - self.embed_dim, - bias=False, - input_is_parallel=True, - ) - - def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_state) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q.view( - q.shape[0], q.shape[1], self.num_local_heads, self.head_dim - ).transpose(1, 2) - k = k.view( - k.shape[0], k.shape[1], self.num_local_heads, self.head_dim - ).transpose(1, 2) - v = v.view( - v.shape[0], v.shape[1], self.num_local_heads, self.head_dim - ).transpose(1, 2) - - # TODO: remove padding in image encoder - attn_output = F.scaled_dot_product_attention( - q, k, v, attn_mask=attention_mask, dropout_p=0.0 - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape( - attn_output.shape[0], attn_output.shape[1], -1 - ) - output, _ = self.o_proj(attn_output) - return output - - class MllamaVisionMLP(nn.Module): def __init__(self, config, quant_config: Optional[QuantizationConfig] = None): super().__init__() @@ -237,7 +183,17 @@ def __init__( self.is_gated = is_gated self.intermediate_size = config.intermediate_size - self.self_attn = MllamaVisionSdpaAttention(config) + self.self_attn = VisionAttention( + self.hidden_size, + self.num_attention_heads, + self.hidden_size, + use_qkv_parallel=True, + quant_config=None, + dropout=0.0, + use_context_forward=False, + use_full_precision_softmax=False, + flatten_batch=False, + ) self.mlp = MllamaVisionMLP(config) self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) @@ -992,6 +948,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: + if "vision_model" in name: + # adapt to VisionAttention + name = name.replace("self_attn.o_proj", "self_attn.proj") + param = params_dict.pop(name) weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py index 1cfa27309fe..4d8a79900f4 100644 --- a/python/sglang/srt/models/olmo.py +++ b/python/sglang/srt/models/olmo.py @@ -15,14 +15,13 @@ # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/olmo.py#L1 """Inference-only OLMo model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from torch import nn from transformers import OlmoConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import 
get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.linear import ( MergedColumnParallelLinear, @@ -32,6 +31,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py old mode 100755 new mode 100644 index 0944b572092..f3e1979f849 --- a/python/sglang/srt/models/olmo2.py +++ b/python/sglang/srt/models/olmo2.py @@ -21,15 +21,13 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather, ) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader - from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -40,11 +38,13 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index 859f4135c4b..10b781d72ff 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -17,30 +17,24 @@ """Inference-only OLMoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, Optional, Tuple import torch -import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig -from vllm.distributed import ( - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, -) -from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - -from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.fused_moe_triton import FusedMoE -from sglang.srt.layers.layernorm import RMSNorm -from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, 
VocabParallelEmbedding, diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index 1e70c7d7874..b7195dbaa28 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -5,9 +5,8 @@ from torch import nn from transformers import Phi3Config from transformers.configuration_utils import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -17,6 +16,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 5492a3e1221..2c99da926b6 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -20,9 +20,8 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -33,6 +32,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 9383fde4d09..46b62f837f6 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -20,9 +20,11 @@ import torch from torch import nn -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -34,12 +36,16 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + kv_cache_scales_loader, +) from sglang.srt.utils import make_layers Qwen2Config = None @@ -242,6 +248,12 @@ def __init__( ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + if 
hasattr(self.config, "scale_emb"): + return self.embed_tokens(input_ids) * self.config.scale_emb + else: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -265,8 +277,50 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.layers[layer_idx], nn.Identity): + layer_self_attn = self.layers[layer_idx].self_attn + if hasattr(layer_self_attn.attn, "k_scale"): + layer_self_attn.attn.k_scale = scaling_factor + layer_self_attn.attn.v_scale = scaling_factor + else: + raise RuntimeError( + "Self attention has no KV cache scaling " "factor attribute!" + ) + class Qwen2ForCausalLM(nn.Module): + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".gate_proj.", + ".down_proj.", + ".up_proj.", + ".q_proj.", + ".k_proj.", + ".v_proj.", + ".o_proj.", + ] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + def __init__( self, config: Qwen2Config, @@ -285,6 +339,9 @@ def __init__( self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + @torch.no_grad() def forward( self, @@ -342,5 +399,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + self.model.load_kv_cache_scales(quantization_param_path) + EntryClass = Qwen2ForCausalLM diff --git a/python/sglang/srt/models/qwen2_eagle.py b/python/sglang/srt/models/qwen2_eagle.py new file mode 100644 index 00000000000..01069ef482c --- /dev/null +++ b/python/sglang/srt/models/qwen2_eagle.py @@ -0,0 +1,131 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +# Adapted from +# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py +"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights.""" + +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn + +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM + +Qwen2Config = None + + +class Qwen2DecoderLayer(Qwen2DecoderLayer): + def __init__( + self, + config: Qwen2Config, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, layer_id, quant_config) + + # Skip the input_layernorm + # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 + if layer_id == 0: + del self.input_layernorm + setattr(self, "input_layernorm", lambda x: x) + + +class Qwen2Model(nn.Module): + def __init__( + self, + config: Qwen2Config, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + Qwen2DecoderLayer( + config, i, quant_config=quant_config, prefix=f"model.layers.{i}" + ) + for i in range(config.num_hidden_layers) + ] + ) + self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + hidden_states = self.fc( + torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1) + ) + + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, + ) + return hidden_states + residual + + +class Qwen2ForCausalLMEagle(Qwen2ForCausalLM): + def __init__( + self, + config: Qwen2Config, + quant_config: Optional[QuantizationConfig] = None, + cache_config=None, + ) -> None: + nn.Module.__init__(self) + self.config = config + self.quant_config = quant_config + self.model = Qwen2Model(config, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + self.logits_processor = LogitsProcessor(config) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." 
+ name + super().load_weights([(name, loaded_weight)]) + + +EntryClass = [Qwen2ForCausalLMEagle] diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 62cd3281d03..6183f30daf4 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -22,14 +22,12 @@ import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.fused_moe_triton import FusedMoE from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( MergedColumnParallelLinear, @@ -38,8 +36,10 @@ RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 2e9ec9d8f50..365891544e0 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" +import logging from functools import lru_cache, partial from typing import Iterable, List, Optional, Tuple, Type, TypedDict @@ -29,17 +30,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange, repeat -from vllm.distributed import parallel_state -from vllm.distributed import utils as dist_utils -from vllm.logger import init_logger +from einops import rearrange from vllm.model_executor.layers.activation import QuickGELU from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig from sglang.srt.hf_transformers_utils import get_processor -from sglang.srt.layers.attention.triton_ops.prefill_attention import ( - context_attention_fwd, -) +from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.pooler import Pooler, PoolingType @@ -50,7 +46,8 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2 import Qwen2Model -logger = init_logger(__name__) +logger = logging.getLogger(__name__) + # === Vision Inputs === # @@ -110,118 +107,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange( - torch.stack((-x2, x1), dim=-1), "... d two -> ... 
(d two)", two=2 - ) - - -def apply_rotary_emb_torch( - x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False -) -> torch.Tensor: - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) - - -def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - t_ = t.float() - cos = freqs.cos() - sin = freqs.sin() - output = apply_rotary_emb_torch(t_, cos, sin).type_as(t) - return output - - -class Qwen2VisionAttention(nn.Module): - - def __init__( - self, - embed_dim: Optional[int] = None, - num_heads: Optional[int] = None, - projection_size: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - # Per attention head and per partition values. - world_size = parallel_state.get_tensor_model_parallel_world_size() - self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads - ) - self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, world_size - ) - - self.qkv = ColumnParallelLinear( - input_size=embed_dim, - output_size=3 * projection_size, - quant_config=quant_config, - ) - self.proj = RowParallelLinear( - input_size=projection_size, output_size=embed_dim, quant_config=quant_config - ) - - def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor = None, - ) -> torch.Tensor: - # [s, b, c] --> [s, b, head * 3 * head_dim] - x, _ = self.qkv(x) - - # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - x = x.view(*new_x_shape) - - # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim] - q, k, v = dist_utils.split_tensor_along_last_dim(x, 3) - batch_size = q.shape[1] - - q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)] - if rotary_pos_emb is not None: - q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) - - seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] - max_seqlen = (seq_lens).max().item() - q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] - - output = torch.empty_like(q) - context_attention_fwd( - q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False - ) - - context_layer = rearrange(output, "(b s) ... 
-> b s ...", b=batch_size) - context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous() - - output, _ = self.proj(context_layer) - return output - - class Qwen2VisionBlock(nn.Module): def __init__( @@ -231,6 +116,7 @@ def __init__( mlp_ratio: float, act_layer: Type[nn.Module] = QuickGELU, norm_layer: Type[nn.Module] = None, + attn_implementation: Optional[str] = "sdpa", quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -239,11 +125,24 @@ def __init__( self.norm1 = norm_layer(dim) self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - - self.attn = Qwen2VisionAttention( + if attn_implementation == "sdpa": + use_context_forward = False + use_full_precision_softmax = False + elif attn_implementation == "flash_attention_2": + use_full_precision_softmax = False + use_context_forward = True + elif attn_implementation == "eager": + use_full_precision_softmax = True + use_context_forward = False + + self.attn = VisionAttention( embed_dim=dim, num_heads=num_heads, projection_size=dim, + use_qkv_parallel=False, + use_context_forward=use_context_forward, + use_full_precision_softmax=use_full_precision_softmax, + flatten_batch=True, quant_config=quant_config, ) self.mlp = Qwen2VisionMLP( @@ -253,9 +152,13 @@ def __init__( def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor ) -> torch.Tensor: - x = x + self.attn( - self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + hidden_states = self.norm1(x) + hidden_states = rearrange(hidden_states, "s b ... -> b s ...") + attn = self.attn( + hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb ) + attn = rearrange(attn, "b s ... -> s b ...") + x = x + attn x = x + self.mlp(self.norm2(x)) return x @@ -394,7 +297,6 @@ def __init__( norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = embed_dim // num_heads self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList( [ Qwen2VisionBlock( @@ -402,6 +304,7 @@ def __init__( num_heads=num_heads, mlp_ratio=mlp_ratio, norm_layer=norm_layer, + attn_implementation="sdpa", quant_config=quant_config, ) for _ in range(depth) @@ -590,10 +493,6 @@ def forward( opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,). (Use input_metadata.mrope_positions to replace it) - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. """ if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": positions = forward_batch.mrope_positions @@ -648,15 +547,18 @@ def forward( num_image_tokens = self.calculate_num_image_tokens( image_grid_thws[idx] ) + left_idx = start_idx + (image_offset - prefix_len) right_idx = ( start_idx + (image_offset - prefix_len) + num_image_tokens ) + inputs_embeds[left_idx:right_idx] = image_embeds[ image_embeds_offset : image_embeds_offset + num_image_tokens ] image_embeds_offset += num_image_tokens + input_ids = None hidden_states = self.model( input_ids=input_ids, positions=positions, @@ -684,10 +586,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
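The Qwen2VisionBlock hunk above maps the new `attn_implementation` argument ("sdpa", "flash_attention_2", "eager") onto the `use_context_forward` / `use_full_precision_softmax` flags of the new `VisionAttention` layer. A minimal sketch of that mapping as a standalone helper; the explicit error for unknown values is added here for illustration only, the patch itself branches on just the three listed strings:

def vision_attn_flags(attn_implementation: str = "sdpa") -> dict:
    # Mirrors the branch introduced in Qwen2VisionBlock.__init__ above.
    if attn_implementation == "sdpa":
        return {"use_context_forward": False, "use_full_precision_softmax": False}
    if attn_implementation == "flash_attention_2":
        return {"use_context_forward": True, "use_full_precision_softmax": False}
    if attn_implementation == "eager":
        return {"use_context_forward": False, "use_full_precision_softmax": True}
    raise ValueError(f"Unknown attn_implementation: {attn_implementation!r}")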
if name.endswith(".bias") and name not in params_dict: continue @@ -696,6 +600,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: + if "visual" in name and "qkv.weight" in name: visual_num_heads = self.config.vision_config.num_heads visual_embed_dim = self.config.vision_config.embed_dim @@ -712,6 +617,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_weight = loaded_weight.view(3, visual_num_heads, head_size) loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.reshape(-1) + + if "visual" in name: + # adapt to VisionAttention + name = name.replace(r"attn.qkv.", r"attn.qkv_proj.") + try: # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 079d54e3c83..c169dd6fba4 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -24,9 +24,8 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.linear import ( MergedColumnParallelLinear, @@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 7a55d50457a..7b3e5bc5ddd 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -47,17 +47,17 @@ from torch import nn from torch.nn.parameter import Parameter from transformers import LlamaConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -460,7 +460,12 @@ def get_num_params(self): params_dict = dict(self.named_parameters()) return len(params_dict) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights_to_module( + self, + fqn: str, + weights: Iterable[Tuple[str, torch.Tensor]], + ): + """Load weights onto submodule pointed by path `fqn`.""" stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -469,7 +474,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] - params_dict = dict(self.named_parameters()) + module = 
self.get_submodule(fqn) + params_dict = dict(module.named_parameters(prefix=fqn, recurse=False)) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: @@ -486,7 +492,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name.endswith(".bias") or name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader @@ -494,12 +500,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name.endswith(".bias") or name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + def load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + ): + """Load weights onto the full model.""" + self.load_weights_to_module("", weights) + class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM): pass diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index e6551421519..7fd24182374 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -21,19 +21,19 @@ import torch from torch import nn from transformers import LlamaConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index e1f3288753b..218b96f9cb4 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -18,25 +18,25 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - -from sglang.srt.layers.fused_moe_triton import fused_moe from sglang.srt.layers.logits_processor import LogitsProcessor +from 
sglang.srt.layers.moe.fused_moe_triton import fused_moe from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index dfb7d4f18bf..6687a4c0133 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -20,7 +20,7 @@ import time import uuid from http import HTTPStatus -from typing import Dict, List +from typing import Dict, List, Optional from fastapi import HTTPException, Request, UploadFile from fastapi.responses import ORJSONResponse, StreamingResponse @@ -40,6 +40,7 @@ generate_chat_conv, register_conv_template, ) +from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput from sglang.srt.openai_api.protocol import ( BatchRequest, @@ -65,7 +66,9 @@ FileDeleteResponse, FileRequest, FileResponse, + FunctionResponse, LogProbs, + ToolCall, TopLogprob, UsageInfo, ) @@ -306,6 +309,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe ret, to_file=True, cache_report=tokenizer_manager.server_args.enable_cache_report, + tool_call_parser=tokenizer_manager.server_args.tool_call_parser, ) else: responses = v1_generate_response( @@ -510,11 +514,14 @@ def v1_generate_request( "stop": request.stop, "stop_token_ids": request.stop_token_ids, "top_p": request.top_p, + "top_k": request.top_k, + "min_p": request.min_p, "presence_penalty": request.presence_penalty, "frequency_penalty": request.frequency_penalty, "repetition_penalty": request.repetition_penalty, "regex": request.regex, "json_schema": request.json_schema, + "ebnf": request.ebnf, "n": request.n, "no_stop_trim": request.no_stop_trim, "ignore_eos": request.ignore_eos, @@ -856,6 +863,7 @@ def v1_chat_generate_request( logprob_start_lens = [] top_logprobs_nums = [] modalities_list = [] + lora_paths = [] # NOTE: with openai API, the prompt's logprobs are always not computed @@ -867,6 +875,18 @@ def v1_chat_generate_request( # None skips any image processing in GenerateReqInput. if not isinstance(request.messages, str): # Apply chat template and its stop strings. 
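The adapter changes above thread the new `top_k`, `min_p`, and `ebnf` request fields straight into the sampling params of an OpenAI-compatible request (the matching protocol fields appear further down in protocol.py). A minimal client-side sketch; the URL, model name, and grammar string are placeholders rather than anything defined by this patch:

import requests

payload = {
    "model": "default",  # placeholder served model name
    "prompt": "The three primary colors are",
    "max_tokens": 32,
    # SRT-only extras added by this patch; upstream OpenAI ignores them.
    "top_k": 20,
    "min_p": 0.05,
    "ebnf": 'root ::= "red, yellow and blue"',
}
resp = requests.post("http://localhost:30000/v1/completions", json=payload)
print(resp.json()["choices"][0]["text"])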
+ tools = None + if request.tools and request.tool_choice != "none": + request.skip_special_tokens = False + if not isinstance(request.tool_choice, str): + tools = [ + item.function.model_dump() + for item in request.tools + if item.function.name == request.tool_choice.function.name + ] + else: + tools = [item.function.model_dump() for item in request.tools] + if chat_template_name is None: openai_compatible_messages = [] for message in request.messages: @@ -886,11 +906,26 @@ def v1_chat_generate_request( openai_compatible_messages = openai_compatible_messages[:-1] else: assistant_prefix = None - prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( - openai_compatible_messages, - tokenize=True, - add_generation_prompt=True, - ) + + try: + prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( + openai_compatible_messages, + tokenize=True, + add_generation_prompt=True, + tools=tools, + ) + except: + # This except branch will be triggered when the chosen model + # has a different tools input format that is not compatiable + # with openAI's apply_chat_template tool_call format, like Mistral. + tools = [t if "function" in t else {"function": t} for t in tools] + prompt_ids = tokenizer_manager.tokenizer.apply_chat_template( + openai_compatible_messages, + tokenize=True, + add_generation_prompt=True, + tools=tools, + ) + if assistant_prefix: prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix) stop = request.stop @@ -918,6 +953,7 @@ def v1_chat_generate_request( return_logprobs.append(request.logprobs) logprob_start_lens.append(-1) top_logprobs_nums.append(request.top_logprobs or 0) + lora_paths.append(request.lora_path) sampling_params = { "temperature": request.temperature, @@ -926,10 +962,13 @@ def v1_chat_generate_request( "stop": stop, "stop_token_ids": request.stop_token_ids, "top_p": request.top_p, + "top_k": request.top_k, + "min_p": request.min_p, "presence_penalty": request.presence_penalty, "frequency_penalty": request.frequency_penalty, "repetition_penalty": request.repetition_penalty, "regex": request.regex, + "ebnf": request.ebnf, "n": request.n, "no_stop_trim": request.no_stop_trim, "ignore_eos": request.ignore_eos, @@ -954,6 +993,7 @@ def v1_chat_generate_request( logprob_start_lens = logprob_start_lens[0] top_logprobs_nums = top_logprobs_nums[0] modalities_list = modalities_list[0] + lora_paths = lora_paths[0] else: if isinstance(input_ids[0], str): prompt_kwargs = {"text": input_ids} @@ -971,12 +1011,15 @@ def v1_chat_generate_request( return_text_in_logprobs=True, rid=request_ids, modalities=modalities_list, + lora_path=lora_paths, ) return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0] -def v1_chat_generate_response(request, ret, to_file=False, cache_report=False): +def v1_chat_generate_response( + request, ret, to_file=False, cache_report=False, tool_call_parser=None +): choices = [] for idx, ret_item in enumerate(ret): @@ -1023,11 +1066,47 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False): finish_reason = ret_item["meta_info"]["finish_reason"] + tool_calls = None + text = ret_item["text"] + + if isinstance(request, list): + tool_choice = request[idx].tool_choice + tools = request[idx].tools + else: + tool_choice = request.tool_choice + tools = request.tools + + if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]): + if finish_reason == "stop": + finish_reason = "tool_calls" + try: + parser = FunctionCallParser(tools, tool_call_parser) + full_normal_text, 
call_info_list = parser.parse_non_stream(text) + tool_calls = [ + ToolCall( + id=str(call_info.tool_index), + function=FunctionResponse( + name=call_info.name, arguments=call_info.parameters + ), + ) + for call_info in call_info_list + ] + except Exception as e: + logger.error(f"Exception: {e}") + return create_error_response( + HTTPStatus.BAD_REQUEST, + "Failed to parse fc related info to json format!", + ) + if to_file: # to make the choice data json serializable choice_data = { "index": 0, - "message": {"role": "assistant", "content": ret_item["text"]}, + "message": { + "role": "assistant", + "content": ret_item["text"] if tool_calls is None else None, + "tool_calls": tool_calls, + }, "logprobs": choice_logprobs, "finish_reason": (finish_reason["type"] if finish_reason else ""), "matched_stop": ( @@ -1039,7 +1118,11 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False): else: choice_data = ChatCompletionResponseChoice( index=idx, - message=ChatMessage(role="assistant", content=ret_item["text"]), + message=ChatMessage( + role="assistant", + content=ret_item["text"] if tool_calls is None else None, + tool_calls=tool_calls, + ), logprobs=choice_logprobs, finish_reason=(finish_reason["type"] if finish_reason else ""), matched_stop=( @@ -1104,6 +1187,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request): adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager) if adapted_request.stream: + parser_dict = {} async def generate_stream_resp(): is_firsts = {} @@ -1116,6 +1200,7 @@ async def generate_stream_resp(): adapted_request, raw_request ): index = content.get("index", 0) + text = content["text"] is_first = is_firsts.get(index, True) stream_buffer = stream_buffers.get(index, "") @@ -1195,29 +1280,111 @@ async def generate_stream_resp(): text = content["text"] delta = text[len(stream_buffer) :] - stream_buffer = stream_buffer + delta - choice_data = ChatCompletionResponseStreamChoice( - index=index, - delta=DeltaMessage(content=delta), - finish_reason=(finish_reason["type"] if finish_reason else ""), - matched_stop=( - finish_reason["matched"] - if finish_reason and "matched" in finish_reason - else None - ), - logprobs=choice_logprobs, - ) - chunk = ChatCompletionStreamResponse( - id=content["meta_info"]["id"], - choices=[choice_data], - model=request.model, - ) + new_stream_buffer = stream_buffer + delta - is_firsts[index] = is_first - stream_buffers[index] = stream_buffer - n_prev_tokens[index] = n_prev_token + if request.tool_choice != "none" and request.tools: + if index not in parser_dict: + parser_dict[index] = FunctionCallParser( + tools=request.tools, + tool_call_parser=tokenizer_manager.server_args.tool_call_parser, + ) + parser = parser_dict[index] + + # parse_increment => returns (normal_text, calls) + normal_text, calls = parser.parse_stream_chunk(delta) + + # 1) if there's normal_text, output it as normal content + if normal_text: + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(content=normal_text), + finish_reason=( + finish_reason["type"] if finish_reason else "" + ), + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + + # 2) if we found calls, we output them as separate chunk(s) + for call_item in calls: + # transform call_item -> FunctionResponse + ToolCall + + if ( + content["meta_info"]["finish_reason"] + and 
content["meta_info"]["finish_reason"]["type"] + == "stop" + ): + latest_delta_len = 0 + if isinstance(call_item.parameters, str): + latest_delta_len = len(call_item.parameters) + + expected_call = json.dumps( + parser.multi_format_parser.detectors[0] + .prev_tool_call_arr[index] + .get("arguments", {}), + ensure_ascii=False, + ) + actual_call = parser.multi_format_parser.detectors[ + 0 + ].streamed_args_for_tool[index] + if latest_delta_len > 0: + actual_call = actual_call[:-latest_delta_len] + remaining_call = expected_call.replace( + actual_call, "", 1 + ) + call_item.parameters = remaining_call + + tool_call = ToolCall( + id=str(call_item.tool_index), + function=FunctionResponse( + name=call_item.name, + arguments=call_item.parameters, + ), + ) + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage( + role="assistant", tool_calls=[tool_call] + ), + finish_reason="tool_call", + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" - yield f"data: {chunk.model_dump_json()}\n\n" + stream_buffers[index] = new_stream_buffer + is_firsts[index] = is_first + + else: + # No tool calls => just treat this as normal text + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(content=delta), + finish_reason=( + finish_reason["type"] if finish_reason else "" + ), + matched_stop=( + finish_reason["matched"] + if finish_reason and "matched" in finish_reason + else None + ), + logprobs=choice_logprobs, + ) + chunk = ChatCompletionStreamResponse( + id=content["meta_info"]["id"], + choices=[choice_data], + model=request.model, + ) + yield f"data: {chunk.model_dump_json()}\n\n" + stream_buffers[index] = new_stream_buffer + is_firsts[index] = is_first if request.stream_options and request.stream_options.include_usage: total_prompt_tokens = sum( tokens @@ -1265,7 +1432,10 @@ async def generate_stream_resp(): ret = [ret] response = v1_chat_generate_response( - request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report + request, + ret, + cache_report=tokenizer_manager.server_args.enable_cache_report, + tool_call_parser=tokenizer_manager.server_args.tool_call_parser, ) return response diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 7c88ad5332e..95b34527edb 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -166,17 +166,21 @@ class CompletionRequest(BaseModel): temperature: float = 1.0 top_p: float = 1.0 user: Optional[str] = None - lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None # Extra parameters for SRT backend only and will be ignored by OpenAI models. 
+ top_k: int = -1 + min_p: float = 0.0 + min_tokens: int = 0 json_schema: Optional[str] = None regex: Optional[str] = None - min_tokens: int = 0 + ebnf: Optional[str] = None repetition_penalty: float = 1.0 stop_token_ids: Optional[List[int]] = None no_stop_trim: bool = False ignore_eos: bool = False skip_special_tokens: bool = True + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + session_params: Optional[Dict] = None class CompletionResponseChoice(BaseModel): @@ -254,6 +258,34 @@ class ResponseFormat(BaseModel): json_schema: Optional[JsonSchemaResponseFormat] = None +class Function(BaseModel): + """Function descriptions.""" + + description: Optional[str] = Field(default=None, examples=[None]) + name: Optional[str] = None + parameters: Optional[object] = None + + +class Tool(BaseModel): + """Function wrapper.""" + + type: str = Field(default="function", examples=["function"]) + function: Function + + +class ToolChoiceFuncName(BaseModel): + """The name of tool choice function.""" + + name: Optional[str] = None + + +class ToolChoice(BaseModel): + """The tool choice definition.""" + + function: ToolChoiceFuncName + type: Literal["function"] = Field(default="function", examples=["function"]) + + class ChatCompletionRequest(BaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create @@ -274,20 +306,45 @@ class ChatCompletionRequest(BaseModel): temperature: float = 0.7 top_p: float = 1.0 user: Optional[str] = None + tools: Optional[List[Tool]] = Field(default=None, examples=[None]) + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field( + default="auto", examples=["none"] + ) # noqa # Extra parameters for SRT backend only and will be ignored by OpenAI models. - regex: Optional[str] = None + top_k: int = -1 + min_p: float = 0.0 min_tokens: int = 0 + regex: Optional[str] = None + ebnf: Optional[str] = None repetition_penalty: float = 1.0 stop_token_ids: Optional[List[int]] = None no_stop_trim: bool = False ignore_eos: bool = False skip_special_tokens: bool = True + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + session_params: Optional[Dict] = None + + +class FunctionResponse(BaseModel): + """Function response.""" + + name: Optional[str] = None + arguments: Optional[str] = None + + +class ToolCall(BaseModel): + """Tool call response.""" + + id: str + type: Literal["function"] = "function" + function: FunctionResponse class ChatMessage(BaseModel): role: Optional[str] = None content: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) class ChatCompletionResponseChoice(BaseModel): @@ -310,6 +367,7 @@ class ChatCompletionResponse(BaseModel): class DeltaMessage(BaseModel): role: Optional[str] = None content: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) class ChatCompletionResponseStreamChoice(BaseModel): diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py new file mode 100644 index 00000000000..a64b2498f23 --- /dev/null +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -0,0 +1,38 @@ +import json +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import Any, Dict, List, Optional + +import dill +import torch + + +@lru_cache(maxsize=None) +def _cache_from_str(json_str: str): + """Deserialize a json string to a Callable object. 
+ This function is cached to avoid redundant deserialization. + """ + data = json.loads(json_str) + return dill.loads(bytes.fromhex(data["callable"])) + + +class CustomLogitProcessor(ABC): + """Abstract base class for callable functions.""" + + @abstractmethod + def __call__( + self, + logits: torch.Tensor, + custom_param_list: Optional[List[Dict[str, Any]]] = None, + ) -> torch.Tensor: + """Define the callable behavior.""" + raise NotImplementedError + + def to_str(self) -> str: + """Serialize the callable function to a JSON-compatible string.""" + return json.dumps({"callable": dill.dumps(self).hex()}) + + @classmethod + def from_str(cls, json_str: str): + """Deserialize a callable function from a JSON string.""" + return _cache_from_str(json_str) diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py index 4c293b89520..fe687c569d4 100644 --- a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +++ b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py @@ -3,6 +3,16 @@ import torch from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs +from sglang.srt.utils import get_compiler_backend + + +@torch.compile(dynamic=True, backend=get_compiler_backend()) +def apply_scaling_penalties(logits, scaling_penalties): + logits[:] = torch.where( + logits > 0, + logits / scaling_penalties, + logits * scaling_penalties, + ) class BatchedRepetitionPenalizer(_BatchedPenalizer): @@ -56,11 +66,8 @@ def _cumulate_output_tokens(self, output_ids: _TokenIDs): self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] def _apply(self, logits: torch.Tensor) -> torch.Tensor: - return torch.where( - logits > 0, - logits / self.cumulated_repetition_penalties, - logits * self.cumulated_repetition_penalties, - ) + apply_scaling_penalties(logits, self.cumulated_repetition_penalties) + return logits def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep] diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index a64a84a62dc..9521a34f4f6 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -3,11 +3,15 @@ import dataclasses import logging import threading -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import torch import sglang.srt.sampling.penaltylib as penaltylib +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor +from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import ( + apply_scaling_penalties, +) logger = logging.getLogger(__name__) @@ -30,6 +34,9 @@ class SamplingBatchInfo: # Dispatch in CUDA graph need_min_p_sampling: bool + # Whether any request has custom logit processor + has_custom_logit_processor: bool + # Bias Tensors vocab_size: int grammars: Optional[List] = None @@ -46,6 +53,14 @@ class SamplingBatchInfo: # Device device: str = "cuda" + # Custom Parameters + custom_params: Optional[List[Optional[Dict[str, Any]]]] = None + + # Custom Logit Processor + custom_logit_processor: Optional[ + Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]] + ] = None + @classmethod def from_schedule_batch( cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: 
bool @@ -70,6 +85,39 @@ def from_schedule_batch( [r.sampling_params.min_p for r in reqs], dtype=torch.float ).to(device, non_blocking=True) + # Check if any request has custom logit processor + has_custom_logit_processor = ( + batch.enable_custom_logit_processor # check the flag first. + and any(r.custom_logit_processor for r in reqs) # then check the requests. + ) + + if has_custom_logit_processor: + # Merge the same type of custom logit processors together + processor_dict = {} + for i, r in enumerate(reqs): + if r.custom_logit_processor is None: + continue + processor_str = r.custom_logit_processor + if processor_str not in processor_dict: + processor_dict[processor_str] = [] + processor_dict[processor_str].append(i) + + merged_custom_logit_processor = { + hash(processor_str): ( + # The deserialized custom logit processor object + CustomLogitProcessor.from_str(processor_str), + # The mask tensor for the requests that use this custom logit processor + torch.zeros(len(reqs), dtype=torch.bool) + .scatter_(0, torch.tensor(true_indices), True) + .to(device, non_blocking=True), + ) + for processor_str, true_indices in processor_dict.items() + } + custom_params = [r.sampling_params.custom_params for r in reqs] + else: + merged_custom_logit_processor = None + custom_params = None + ret = cls( temperatures=temperatures, top_ps=top_ps, @@ -77,8 +125,11 @@ def from_schedule_batch( min_ps=min_ps, need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs), is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs), + has_custom_logit_processor=has_custom_logit_processor, vocab_size=vocab_size, device=device, + custom_params=custom_params, + custom_logit_processor=merged_custom_logit_processor, ) # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge. @@ -178,6 +229,8 @@ def update_regex_vocab_mask(self): def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): self.penalizer_orchestrator.filter(unfinished_indices, new_indices) + if self.has_custom_logit_processor: + self._filter_batch_custom_logit_processor(unfinished_indices, new_indices) for item in [ "temperatures", @@ -190,6 +243,27 @@ def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor) if value is not None: # logit_bias can be None setattr(self, item, value[new_indices]) + def _filter_batch_custom_logit_processor( + self, unfinished_indices: List[int], new_indices: torch.Tensor + ): + """Filter the custom logit processor and custom params""" + + self.custom_logit_processor = { + k: (p, mask[new_indices]) + for k, (p, mask) in self.custom_logit_processor.items() + if any( + mask[new_indices] + ) # ignore the custom logit processor whose mask is all False + } + self.custom_params = [self.custom_params[i] for i in unfinished_indices] + + # If the custom logit processor is an empty dict, set the flag to False, + # and set the custom logit processor and custom params to None. 
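The new `CustomLogitProcessor` ABC above (serialized with `dill` and cached by `_cache_from_str`) is what this `SamplingBatchInfo` plumbing deserializes and masks per request. A minimal user-side sketch; the class name and the `banned_token_id` key are illustrative, and how the serialized string reaches a request is outside this hunk:

import torch
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

class BanTokenProcessor(CustomLogitProcessor):
    """Suppress one token id, supplied per request via custom_params."""

    def __call__(self, logits, custom_param_list=None):
        if custom_param_list:
            for i, params in enumerate(custom_param_list):
                if params and "banned_token_id" in params:
                    logits[i, params["banned_token_id"]] = float("-inf")
        return logits

# Round-trip used by the batching code: serialize on the request side,
# deserialize (with caching) on the scheduler side.
payload = BanTokenProcessor().to_str()
restored = CustomLogitProcessor.from_str(payload)
assert isinstance(restored, BanTokenProcessor)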
+ if len(self.custom_logit_processor) == 0: + self.custom_logit_processor = None + self.custom_params = None + self.has_custom_logit_processor = False + @staticmethod def merge_bias_tensor( lhs: torch.Tensor, @@ -215,9 +289,76 @@ def merge_bias_tensor( return None + @staticmethod + def merge_custom_logit_processor( + lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]], + rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]], + bs1: int, + bs2: int, + device: str, + ): + if lhs is None and rhs is None: + return None + lhs, rhs = lhs or {}, rhs or {} + + keys = set(lhs.keys()).union(set(rhs.keys())) + merged_dict = {} + + for k in keys: + # Get the logit processor object + processor = lhs[k][0] if k in lhs else rhs[k][0] + # Get and merge the mask tensors from the two dicts + left_mask = ( + lhs[k][1] + if k in lhs + else torch.zeros(bs1, dtype=torch.bool, device=device) + ) + right_mask = ( + rhs[k][1] + if k in rhs + else torch.zeros(bs2, dtype=torch.bool, device=device) + ) + merged_dict[k] = (processor, torch.cat([left_mask, right_mask])) + + assert merged_dict[k][1].shape[0] == bs1 + bs2, ( + f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match " + f"the sum of the batch sizes of the two masks ({bs1 + bs2})" + f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}" + f"\n{lhs=}\n{rhs=}" + ) + + return merged_dict + def merge_batch(self, other: "SamplingBatchInfo"): self.penalizer_orchestrator.merge(other.penalizer_orchestrator) + # Merge the logit bias tensor + self.logit_bias = SamplingBatchInfo.merge_bias_tensor( + self.logit_bias, other.logit_bias, len(self), len(other), self.device + ) + # Merge the custom logit processors and custom params lists + if self.has_custom_logit_processor or other.has_custom_logit_processor: + # Merge the custom logit processors + self.custom_logit_processor = ( + SamplingBatchInfo.merge_custom_logit_processor( + self.custom_logit_processor, + other.custom_logit_processor, + len(self), + len(other), + self.device, + ) + ) + # Merge the custom params lists + self.custom_params = self.custom_params or [None] * len(self) + other.custom_params = other.custom_params or [None] * len(other) + self.custom_params.extend(other.custom_params) + + # Set the flag to True if any of the two has custom logit processor + self.has_custom_logit_processor = True + + # Note: becasue the __len()__ operator is defined on the temperatures tensor, + # please make sure any merge operation with len(self) or len(other) is done before + # the merge operation of the temperatures tensor below. 
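As a toy illustration of the merge performed by `merge_custom_logit_processor` above: a processor key present on only one side is padded with an all-False mask for the other side's requests, so every concatenated mask ends up with length `bs1 + bs2`. The keys, names, and masks below are made up:

import torch

bs1, bs2 = 2, 3
lhs = {111: ("proc_a", torch.tensor([True, False]))}
rhs = {222: ("proc_b", torch.tensor([False, True, True]))}

merged = {}
for key in set(lhs) | set(rhs):
    processor = (lhs.get(key) or rhs.get(key))[0]
    left = lhs[key][1] if key in lhs else torch.zeros(bs1, dtype=torch.bool)
    right = rhs[key][1] if key in rhs else torch.zeros(bs2, dtype=torch.bool)
    merged[key] = (processor, torch.cat([left, right]))

assert all(mask.shape[0] == bs1 + bs2 for _, mask in merged.values())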
for item in [ "temperatures", "top_ps", @@ -229,6 +370,21 @@ def merge_batch(self, other: "SamplingBatchInfo"): setattr(self, item, torch.concat([self_val, other_val])) self.is_all_greedy = self.is_all_greedy and other.is_all_greedy - self.logit_bias = SamplingBatchInfo.merge_bias_tensor( - self.logit_bias, other.logit_bias, len(self), len(other), self.device - ) + self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling + + def apply_logits_bias(self, logits: torch.Tensor): + # Apply logit_bias + if self.logit_bias is not None: + logits.add_(self.logit_bias) + + # min-token, presence, frequency + if self.linear_penalties is not None: + logits.add_(self.linear_penalties) + + # repetition + if self.scaling_penalties is not None: + apply_scaling_penalties(logits, self.scaling_penalties) + + # Apply regex vocab_mask + if self.vocab_mask is not None: + self.apply_mask(logits=logits, vocab_mask=self.vocab_mask) diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index 64d5e0783ea..2224fb0919a 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -13,12 +13,20 @@ # ============================================================================== """Sampling parameters for text generation.""" -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union _SAMPLING_EPS = 1e-6 class SamplingParams: + """ + The sampling parameters. + + See docs/references/sampling_params.md or + https://docs.sglang.ai/references/sampling_params.html + for the documentation. + """ + def __init__( self, max_new_tokens: int = 128, @@ -33,12 +41,14 @@ def __init__( repetition_penalty: float = 1.0, min_new_tokens: int = 0, spaces_between_special_tokens: bool = True, - regex: Optional[str] = None, n: int = 1, json_schema: Optional[str] = None, + regex: Optional[str] = None, + ebnf: Optional[str] = None, no_stop_trim: bool = False, ignore_eos: bool = False, skip_special_tokens: bool = True, + custom_params: Optional[Dict[str, Any]] = None, ) -> None: self.temperature = temperature self.top_p = top_p @@ -60,7 +70,9 @@ def __init__( self.regex = regex self.n = n self.json_schema = json_schema + self.ebnf = ebnf self.no_stop_trim = no_stop_trim + self.custom_params = custom_params # Process some special cases if self.temperature < _SAMPLING_EPS: @@ -111,8 +123,13 @@ def verify(self): f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got " f"{self.min_new_tokens}." ) - if self.regex is not None and self.json_schema is not None: - raise ValueError("regex and json_schema cannot be both set.") + grammars = [ + self.json_schema, + self.regex, + self.ebnf, + ] # since mutually exclusive, only one can be set + if sum(x is not None for x in grammars) > 1: + raise ValueError("Only one of regex, json_schema, or ebnf can be set.") def normalize(self, tokenizer): # Process stop strings diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 29bc44eb524..869a984d0cf 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -11,1027 +11,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -""" -The entry point of inference server. -SRT = SGLang Runtime. 
-""" -import asyncio -import atexit -import dataclasses -import json -import logging -import multiprocessing as mp -import os -import signal -import threading -import time -from http import HTTPStatus -from typing import AsyncIterator, Dict, List, Optional, Union - -# Fix a bug of Python threading -setattr(threading, "_register_atexit", lambda *args, **kwargs: None) - -import aiohttp -import orjson -import requests -import uvicorn -import uvloop -from fastapi import FastAPI, File, Form, Request, UploadFile -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import ORJSONResponse, Response, StreamingResponse -from uvicorn.config import LOGGING_CONFIG - -from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint -from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.managers.data_parallel_controller import ( - run_data_parallel_controller_process, -) -from sglang.srt.managers.detokenizer_manager import run_detokenizer_process -from sglang.srt.managers.io_struct import ( - CloseSessionReqInput, - EmbeddingReqInput, - GenerateReqInput, - GetWeightsByNameReqInput, - InitWeightsUpdateGroupReqInput, - OpenSessionReqInput, - UpdateWeightFromDiskReqInput, - UpdateWeightsFromDistributedReqInput, -) -from sglang.srt.managers.scheduler import run_scheduler_process -from sglang.srt.managers.tokenizer_manager import TokenizerManager -from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency -from sglang.srt.openai_api.adapter import ( - load_chat_template_for_openai_api, - v1_batches, - v1_cancel_batch, - v1_chat_completions, - v1_completions, - v1_delete_file, - v1_embeddings, - v1_files_create, - v1_retrieve_batch, - v1_retrieve_file, - v1_retrieve_file_content, -) -from sglang.srt.openai_api.protocol import ModelCard, ModelList -from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import ( - add_api_key_middleware, - add_prometheus_middleware, - assert_pkg_version, - configure_logger, - delete_directory, - is_port_available, - kill_process_tree, - maybe_set_triton_cache_manager, - prepare_model_and_tokenizer, - set_prometheus_multiproc_dir, - set_ulimit, -) -from sglang.utils import get_exception_traceback -from sglang.version import __version__ - -logger = logging.getLogger(__name__) - -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - -# Fast API -app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -tokenizer_manager: TokenizerManager = None -scheduler_info: Dict = None - -##### Native API endpoints ##### - - -@app.get("/health") -async def health() -> Response: - """Check the health of the http server.""" - return Response(status_code=200) - - -@app.get("/health_generate") -async def health_generate(request: Request) -> Response: - """Check the health of the inference server by generating one token.""" - - if tokenizer_manager.is_generation: - gri = GenerateReqInput( - input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7} - ) - else: - gri = EmbeddingReqInput( - input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7} - ) - - try: - async for _ in tokenizer_manager.generate_request(gri, request): - break - return Response(status_code=200) - except Exception as e: - logger.exception(e) - return Response(status_code=503) - - -@app.get("/get_model_info") -async def get_model_info(): - """Get the model information.""" - result = { - "model_path": 
tokenizer_manager.model_path, - "tokenizer_path": tokenizer_manager.server_args.tokenizer_path, - "is_generation": tokenizer_manager.is_generation, - } - return result - - -@app.get("/get_server_info") -async def get_server_info(): - return { - **dataclasses.asdict(tokenizer_manager.server_args), # server args - **scheduler_info, - "version": __version__, - } - - -@app.post("/flush_cache") -async def flush_cache(): - """Flush the radix cache.""" - tokenizer_manager.flush_cache() - return Response( - content="Cache flushed.\nPlease check backend logs for more details. " - "(When there are running or waiting requests, the operation will not be performed.)\n", - status_code=200, - ) - - -@app.get("/start_profile") -@app.post("/start_profile") -async def start_profile_async(): - """Start profiling.""" - tokenizer_manager.start_profile() - return Response( - content="Start profiling.\n", - status_code=200, - ) - - -@app.get("/stop_profile") -@app.post("/stop_profile") -async def stop_profile_async(): - """Stop profiling.""" - tokenizer_manager.stop_profile() - return Response( - content="Stop profiling. This will take some time.\n", - status_code=200, - ) - - -@app.post("/update_weights_from_disk") -@time_func_latency -async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): - """Update the weights from disk inplace without re-launching the server.""" - success, message = await tokenizer_manager.update_weights_from_disk(obj, request) - content = {"success": success, "message": message} - if success: - return ORJSONResponse( - content, - status_code=HTTPStatus.OK, - ) - else: - return ORJSONResponse( - content, - status_code=HTTPStatus.BAD_REQUEST, - ) - - -@app.post("/init_weights_update_group") -async def init_weights_update_group( - obj: InitWeightsUpdateGroupReqInput, request: Request -): - """Initialize the parameter update group.""" - success, message = await tokenizer_manager.init_weights_update_group(obj, request) - content = {"success": success, "message": message} - if success: - return ORJSONResponse(content, status_code=200) - else: - return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) - - -@app.post("/update_weights_from_distributed") -async def update_weights_from_distributed( - obj: UpdateWeightsFromDistributedReqInput, request: Request -): - """Update model parameter from distributed online.""" - success, message = await tokenizer_manager.update_weights_from_distributed( - obj, request - ) - content = {"success": success, "message": message} - if success: - return ORJSONResponse(content, status_code=200) - else: - return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) - - -@app.api_route("/get_weights_by_name", methods=["GET", "POST"]) -async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): - """Get model parameter by name.""" - try: - ret = await tokenizer_manager.get_weights_by_name(obj, request) - if ret is None: - return ORJSONResponse( - {"error": {"message": "Get parameter by name failed"}}, - status_code=HTTPStatus.BAD_REQUEST, - ) - else: - return ORJSONResponse(ret, status_code=200) - except Exception as e: - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) - - -@app.api_route("/open_session", methods=["GET", "POST"]) -async def open_session(obj: OpenSessionReqInput, request: Request): - """Open a session, and return its unique session id.""" - try: - session_id = await tokenizer_manager.open_session(obj, request) - return session_id - except 
Exception as e: - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) - - -@app.api_route("/close_session", methods=["GET", "POST"]) -async def close_session(obj: CloseSessionReqInput, request: Request): - """Close the session""" - try: - await tokenizer_manager.close_session(obj, request) - return Response(status_code=200) - except Exception as e: - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) - - -# fastapi implicitly converts json in the request to obj (dataclass) -@app.api_route("/generate", methods=["POST", "PUT"]) -@time_func_latency -async def generate_request(obj: GenerateReqInput, request: Request): - """Handle a generate request.""" - if obj.stream: - - async def stream_results() -> AsyncIterator[bytes]: - try: - async for out in tokenizer_manager.generate_request(obj, request): - yield b"data: " + orjson.dumps( - out, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" - except ValueError as e: - out = {"error": {"message": str(e)}} - yield b"data: " + orjson.dumps( - out, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" - yield b"data: [DONE]\n\n" - - return StreamingResponse( - stream_results(), - media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(obj), - ) - else: - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) - - -@app.api_route("/encode", methods=["POST", "PUT"]) -@time_func_latency -async def encode_request(obj: EmbeddingReqInput, request: Request): - """Handle an embedding request.""" - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) - - -@app.api_route("/classify", methods=["POST", "PUT"]) -@time_func_latency -async def classify_request(obj: EmbeddingReqInput, request: Request): - """Handle a reward model request. 
Now the arguments and return values are the same as embedding models.""" - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) - - -##### OpenAI-compatible API endpoints ##### - - -@app.post("/v1/completions") -@time_func_latency -async def openai_v1_completions(raw_request: Request): - return await v1_completions(tokenizer_manager, raw_request) - - -@app.post("/v1/chat/completions") -@time_func_latency -async def openai_v1_chat_completions(raw_request: Request): - return await v1_chat_completions(tokenizer_manager, raw_request) - - -@app.post("/v1/embeddings", response_class=ORJSONResponse) -@time_func_latency -async def openai_v1_embeddings(raw_request: Request): - response = await v1_embeddings(tokenizer_manager, raw_request) - return response - - -@app.get("/v1/models", response_class=ORJSONResponse) -def available_models(): - """Show available models.""" - served_model_names = [tokenizer_manager.served_model_name] - model_cards = [] - for served_model_name in served_model_names: - model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) - return ModelList(data=model_cards) - - -@app.post("/v1/files") -async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): - return await v1_files_create( - file, purpose, tokenizer_manager.server_args.file_storage_pth - ) - - -@app.delete("/v1/files/{file_id}") -async def delete_file(file_id: str): - # https://platform.openai.com/docs/api-reference/files/delete - return await v1_delete_file(file_id) - - -@app.post("/v1/batches") -async def openai_v1_batches(raw_request: Request): - return await v1_batches(tokenizer_manager, raw_request) - - -@app.post("/v1/batches/{batch_id}/cancel") -async def cancel_batches(batch_id: str): - # https://platform.openai.com/docs/api-reference/batch/cancel - return await v1_cancel_batch(tokenizer_manager, batch_id) - - -@app.get("/v1/batches/{batch_id}") -async def retrieve_batch(batch_id: str): - return await v1_retrieve_batch(batch_id) - - -@app.get("/v1/files/{file_id}") -async def retrieve_file(file_id: str): - # https://platform.openai.com/docs/api-reference/files/retrieve - return await v1_retrieve_file(file_id) - - -@app.get("/v1/files/{file_id}/content") -async def retrieve_file_content(file_id: str): - # https://platform.openai.com/docs/api-reference/files/retrieve-contents - return await v1_retrieve_file_content(file_id) - - -def launch_engine( - server_args: ServerArgs, -): - """ - Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. - """ - - global tokenizer_manager - global scheduler_info - - # Configure global environment - configure_logger(server_args) - server_args.check_server_args() - _set_envs_and_config(server_args) - - # Allocate ports for inter-process communications - port_args = PortArgs.init_new(server_args) - logger.info(f"{server_args=}") - - # If using model from www.modelscope.cn, first download the model. 
- server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( - server_args.model_path, server_args.tokenizer_path - ) - - if server_args.dp_size == 1: - # Launch tensor parallel scheduler processes - scheduler_procs = [] - scheduler_pipe_readers = [] - tp_size_per_node = server_args.tp_size // server_args.nnodes - tp_rank_range = range( - tp_size_per_node * server_args.node_rank, - tp_size_per_node * (server_args.node_rank + 1), - ) - for tp_rank in tp_rank_range: - reader, writer = mp.Pipe(duplex=False) - gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node - proc = mp.Process( - target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, None, writer), - ) - proc.start() - scheduler_procs.append(proc) - scheduler_pipe_readers.append(reader) - - if server_args.node_rank >= 1: - # For other nodes, they do not need to run tokenizer or detokenizer, - # so they can just wait here. - for proc in scheduler_procs: - proc.join() - else: - # Launch the data parallel controller - reader, writer = mp.Pipe(duplex=False) - scheduler_pipe_readers = [reader] - proc = mp.Process( - target=run_data_parallel_controller_process, - args=(server_args, port_args, writer), - ) - proc.start() - - # Launch detokenizer process - detoken_proc = mp.Process( - target=run_detokenizer_process, - args=( - server_args, - port_args, - ), - ) - detoken_proc.start() - - # Launch tokenizer process - tokenizer_manager = TokenizerManager(server_args, port_args) - if server_args.chat_template: - load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) - - # Wait for model to finish loading - scheduler_infos = [] - for i in range(len(scheduler_pipe_readers)): - data = scheduler_pipe_readers[i].recv() - - if data["status"] != "ready": - raise RuntimeError( - "Initialization failed. Please see the error messages above." - ) - scheduler_infos.append(data) - - # Assume all schedulers have same max_total_num_tokens - scheduler_info = scheduler_infos[0] - - -def launch_server( - server_args: ServerArgs, - pipe_finish_writer: Optional[mp.connection.Connection] = None, -): - """ - Launch SRT (SGLang Runtime) Server - - The SRT server consists of an HTTP server and the SRT engine. - - 1. HTTP server: A FastAPI server that routes requests to the engine. - 2. SRT engine: - 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. - 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. - 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. - - Note: - 1. The HTTP server and TokenizerManager both run in the main process. - 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. 
- """ - launch_engine(server_args=server_args) - - # Add api key authorization - if server_args.api_key: - add_api_key_middleware(app, server_args.api_key) - - # Add prometheus middleware - if server_args.enable_metrics: - add_prometheus_middleware(app) - enable_func_timer() - - # Send a warmup request - t = threading.Thread( - target=_wait_and_warmup, args=(server_args, pipe_finish_writer) - ) - t.start() - - try: - # Update logging configs - LOGGING_CONFIG["formatters"]["default"][ - "fmt" - ] = "[%(asctime)s] %(levelprefix)s %(message)s" - LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S" - LOGGING_CONFIG["formatters"]["access"][ - "fmt" - ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s' - LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S" - - # Listen for HTTP requests - uvicorn.run( - app, - host=server_args.host, - port=server_args.port, - log_level=server_args.log_level_http or server_args.log_level, - timeout_keep_alive=5, - loop="uvloop", - ) - finally: - t.join() - - -def _set_envs_and_config(server_args: ServerArgs): - # Set global environments - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = "0" - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" - - # Set prometheus env vars - if server_args.enable_metrics: - set_prometheus_multiproc_dir() - - # Set ulimit - set_ulimit() - - # Fix triton bugs - if server_args.tp_size * server_args.dp_size > 1: - # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. - maybe_set_triton_cache_manager() - - # Check flashinfer version - if server_args.attention_backend == "flashinfer": - assert_pkg_version( - "flashinfer", - "0.1.6", - "Please uninstall the old version and " - "reinstall the latest version by following the instructions " - "at https://docs.flashinfer.ai/installation.html.", - ) - - # Register the signal handler. - # The child processes will send SIGQUIT to this process when any error happens - # This process then clean up the whole process tree - def sigquit_handler(signum, frame): - kill_process_tree(os.getpid()) - - signal.signal(signal.SIGQUIT, sigquit_handler) - - # Set mp start method - mp.set_start_method("spawn", force=True) - - -def _wait_and_warmup(server_args, pipe_finish_writer): - headers = {} - url = server_args.url() - if server_args.api_key: - headers["Authorization"] = f"Bearer {server_args.api_key}" - - # Wait until the server is launched - success = False - for _ in range(120): - time.sleep(1) - try: - res = requests.get(url + "/get_model_info", timeout=5, headers=headers) - assert res.status_code == 200, f"{res=}, {res.text=}" - success = True - break - except (AssertionError, requests.exceptions.RequestException): - last_traceback = get_exception_traceback() - pass - - if not success: - if pipe_finish_writer is not None: - pipe_finish_writer.send(last_traceback) - logger.error(f"Initialization failed. 
warmup error: {last_traceback}") - kill_process_tree(os.getpid()) - return - - model_info = res.json() - - # Send a warmup request - request_name = "/generate" if model_info["is_generation"] else "/encode" - max_new_tokens = 8 if model_info["is_generation"] else 1 - json_data = { - "sampling_params": { - "temperature": 0, - "max_new_tokens": max_new_tokens, - }, - } - if server_args.skip_tokenizer_init: - json_data["input_ids"] = [10, 11, 12] - else: - json_data["text"] = "The capital city of France is" - - try: - for _ in range(server_args.dp_size): - res = requests.post( - url + request_name, - json=json_data, - headers=headers, - timeout=600, - ) - assert res.status_code == 200, f"{res}" - except Exception: - last_traceback = get_exception_traceback() - if pipe_finish_writer is not None: - pipe_finish_writer.send(last_traceback) - logger.error(f"Initialization failed. warmup error: {last_traceback}") - kill_process_tree(os.getpid()) - return - - # Debug print - # logger.info(f"{res.json()=}") - - logger.info("The server is fired up and ready to roll!") - if pipe_finish_writer is not None: - pipe_finish_writer.send("ready") - - if server_args.delete_ckpt_after_loading: - delete_directory(server_args.model_path) - - -STREAM_END_SYMBOL = b"data: [DONE]" -STREAM_CHUNK_START_SYMBOL = b"data:" - - -class Engine: - """ - SRT Engine without an HTTP server layer. - - This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where - launching the HTTP server adds unnecessary complexity or overhead, - """ - - def __init__(self, log_level: str = "error", *args, **kwargs): - """See the arguments in server_args.py::ServerArgs""" - - # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown() - atexit.register(self.shutdown) - - server_args = ServerArgs(*args, log_level=log_level, **kwargs) - launch_engine(server_args=server_args) - - def generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Union[List[Dict], Dict]] = None, - # The token ids for text; one can either specify text or input_ids. 
- input_ids: Optional[Union[List[List[int]], List[int]]] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - lora_path: Optional[List[Optional[str]]] = None, - stream: bool = False, - ): - obj = GenerateReqInput( - text=prompt, - input_ids=input_ids, - sampling_params=sampling_params, - return_logprob=return_logprob, - logprob_start_len=logprob_start_len, - top_logprobs_num=top_logprobs_num, - lora_path=lora_path, - stream=stream, - ) - - # get the current event loop - loop = asyncio.get_event_loop() - ret = loop.run_until_complete(generate_request(obj, None)) - - if stream is True: - - def generator_wrapper(): - offset = 0 - loop = asyncio.get_event_loop() - generator = ret.body_iterator - while True: - chunk = loop.run_until_complete(generator.__anext__()) - - if chunk.startswith(STREAM_END_SYMBOL): - break - else: - data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :]) - data["text"] = data["text"][offset:] - offset += len(data["text"]) - yield data - - # we cannot yield in the scope of generate() because python does not allow yield + return in the same function - # however, it allows to wrap the generator as a subfunction and return - return generator_wrapper() - else: - return ret - - async def async_generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Dict] = None, - # The token ids for text; one can either specify text or input_ids. - input_ids: Optional[Union[List[List[int]], List[int]]] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - lora_path: Optional[List[Optional[str]]] = None, - stream: bool = False, - ): - obj = GenerateReqInput( - text=prompt, - input_ids=input_ids, - sampling_params=sampling_params, - return_logprob=return_logprob, - logprob_start_len=logprob_start_len, - top_logprobs_num=top_logprobs_num, - lora_path=lora_path, - stream=stream, - ) - - ret = await generate_request(obj, None) - - if stream is True: - generator = ret.body_iterator - - async def generator_wrapper(): - - offset = 0 - - while True: - chunk = await generator.__anext__() - - if chunk.startswith(STREAM_END_SYMBOL): - break - else: - data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :]) - data["text"] = data["text"][offset:] - offset += len(data["text"]) - yield data - - return generator_wrapper() - else: - return ret - - def shutdown(self): - kill_process_tree(os.getpid(), include_parent=False) - - def get_tokenizer(self): - global tokenizer_manager - - if tokenizer_manager is None: - raise ReferenceError("Tokenizer Manager is not initialized.") - else: - return tokenizer_manager.tokenizer - - def encode( - self, - prompt: Union[str, List[str], List[Dict], List[List[Dict]]], - ): - obj = EmbeddingReqInput(text=prompt) - - # get the current event loop - loop = asyncio.get_event_loop() - return loop.run_until_complete(encode_request(obj, None)) - - def start_profile(self): - tokenizer_manager.start_profile() - - def stop_profile(self): - tokenizer_manager.stop_profile() - - def get_server_info(self): - return { - **dataclasses.asdict(tokenizer_manager.server_args), # server args - **scheduler_info, - "version": __version__, - } - - def init_weights_update_group( - self, - master_address: str, - 
master_port: int, - rank_offset: int, - world_size: int, - group_name: str, - backend: str = "nccl", - ): - """Initialize parameter update group.""" - obj = InitWeightsUpdateGroupReqInput( - master_address=master_address, - master_port=master_port, - rank_offset=rank_offset, - world_size=world_size, - group_name=group_name, - backend=backend, - ) - - async def _init_group(): - return await tokenizer_manager.init_weights_update_group(obj, None) - - loop = asyncio.get_event_loop() - return loop.run_until_complete(_init_group()) - - def update_weights_from_distributed(self, name, dtype, shape): - """Update weights from distributed source.""" - obj = UpdateWeightsFromDistributedReqInput( - name=name, - dtype=dtype, - shape=shape, - ) - - async def _update_weights(): - return await tokenizer_manager.update_weights_from_distributed(obj, None) - - loop = asyncio.get_event_loop() - return loop.run_until_complete(_update_weights()) - - def get_weights_by_name(self, name, truncate_size=100): - """Get weights by parameter name.""" - obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) - - async def _get_weights(): - return await tokenizer_manager.get_weights_by_name(obj, None) - - loop = asyncio.get_event_loop() - return loop.run_until_complete(_get_weights()) - - -class Runtime: - """ - A wrapper for the HTTP server. - This is used for launching the server in a python program without - using the commond line interface. - - It is mainly used for the frontend language. - You should use the Engine class if you want to do normal offline processing. - """ - - def __init__( - self, - log_level: str = "error", - *args, - **kwargs, - ): - """See the arguments in server_args.py::ServerArgs""" - self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) - - # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown() - atexit.register(self.shutdown) - - # Pre-allocate ports - for port in range(10000, 40000): - if is_port_available(port): - break - port += 1 - self.server_args.port = port - - self.url = self.server_args.url() - self.generate_url = self.url + "/generate" - - # NOTE: We store pid instead of proc to fix some issues during __delete__ - self.pid = None - pipe_reader, pipe_writer = mp.Pipe(duplex=False) - - proc = mp.Process( - target=launch_server, - args=(self.server_args, pipe_writer), - ) - proc.start() - pipe_writer.close() - self.pid = proc.pid - - try: - init_state = pipe_reader.recv() - except EOFError: - init_state = "" - - if init_state != "ready": - self.shutdown() - raise RuntimeError( - "Initialization failed. Please see the error messages above." 
- ) - - self.endpoint = RuntimeEndpoint(self.url) - - def shutdown(self): - if self.pid is not None: - kill_process_tree(self.pid) - self.pid = None - - def cache_prefix(self, prefix: str): - self.endpoint.cache_prefix(prefix) - - def get_tokenizer(self): - return get_tokenizer( - self.server_args.tokenizer_path, - tokenizer_mode=self.server_args.tokenizer_mode, - trust_remote_code=self.server_args.trust_remote_code, - ) - - async def async_generate( - self, - prompt: str, - sampling_params: Optional[Dict] = None, - ): - if self.server_args.skip_tokenizer_init: - json_data = { - "input_ids": prompt, - "sampling_params": sampling_params, - "stream": True, - } - else: - json_data = { - "text": prompt, - "sampling_params": sampling_params, - "stream": True, - } - pos = 0 - - timeout = aiohttp.ClientTimeout(total=3 * 3600) - async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.post(self.generate_url, json=json_data) as response: - async for chunk, _ in response.content.iter_chunks(): - chunk = chunk.decode("utf-8") - if chunk and chunk.startswith("data:"): - if chunk == "data: [DONE]\n\n": - break - data = json.loads(chunk[5:].strip("\n")) - if "text" in data: - cur = data["text"][pos:] - if cur: - yield cur - pos += len(cur) - else: - yield data - - add_request = async_generate - - def generate( - self, - prompt: Union[str, List[str]], - sampling_params: Optional[Dict] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - lora_path: Optional[List[Optional[str]]] = None, - ): - json_data = { - "text": prompt, - "sampling_params": sampling_params, - "return_logprob": return_logprob, - "logprob_start_len": logprob_start_len, - "top_logprobs_num": top_logprobs_num, - "lora_path": lora_path, - } - assert not isinstance(lora_path, list) or len(lora_path) == len(prompt) - response = requests.post( - self.url + "/generate", - json=json_data, - ) - return json.dumps(response.json()) - - def encode( - self, - prompt: Union[str, List[str], List[Dict], List[List[Dict]]], - ): - json_data = {"text": prompt} - response = requests.post(self.url + "/encode", json=json_data) - return json.dumps(response.json()) - - async def get_server_info(self): - async with aiohttp.ClientSession() as session: - async with session.get(f"{self.url}/get_server_info") as response: - if response.status == 200: - return await response.json() - else: - error_data = await response.json() - raise RuntimeError( - f"Failed to get server info. {error_data['error']['message']}" - ) - - def __del__(self): - self.shutdown() +# Some shortcuts for backward compatibility. +# They will be removed in new versions. 
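For orientation, a minimal offline-use sketch of the Engine deleted above (and re-exported below from its new location). It assumes the relocated class keeps the constructor and generate() signature shown in the removed code; the model path is a placeholder, and actually running it requires real weights and a GPU.

from sglang.srt.entrypoints.engine import Engine

# Keyword arguments are forwarded to ServerArgs; the path below is a placeholder.
engine = Engine(model_path="/models/placeholder")
result = engine.generate(
    prompt="The capital city of France is",
    sampling_params={"temperature": 0, "max_new_tokens": 8},
)
print(result)  # the response carries the generated text
engine.shutdown()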
+from sglang.srt.entrypoints.engine import Engine +from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c2e75a642bd..f9340e47764 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -29,8 +29,9 @@ get_nvgpu_memory_capacity, is_flashinfer_available, is_hip, - is_ipv6, is_port_available, + is_valid_ipv6_address, + nullable_str, ) logger = logging.getLogger(__name__) @@ -42,11 +43,11 @@ class ServerArgs: model_path: str tokenizer_path: Optional[str] = None tokenizer_mode: str = "auto" - skip_tokenizer_init: bool = False load_format: str = "auto" trust_remote_code: bool = True dtype: str = "auto" kv_cache_dtype: str = "auto" + quantization_param_path: nullable_str = None quantization: Optional[str] = None context_length: Optional[int] = None device: str = "cuda" @@ -54,8 +55,9 @@ class ServerArgs: chat_template: Optional[str] = None is_embedding: bool = False revision: Optional[str] = None + skip_tokenizer_init: bool = False - # Port + # Port for the HTTP server host: str = "127.0.0.1" port: int = 30000 @@ -68,10 +70,12 @@ class ServerArgs: schedule_policy: str = "lpm" schedule_conservativeness: float = 1.0 cpu_offload_gb: int = 0 + prefill_only_one_req: bool = False # Other runtime options tp_size: int = 1 stream_interval: int = 1 + stream_output: bool = False random_seed: Optional[int] = None constrained_json_whitespace_pattern: Optional[str] = None watchdog_timeout: float = 300 @@ -88,12 +92,13 @@ class ServerArgs: # API related api_key: Optional[str] = None - file_storage_pth: str = "SGLang_storage" + file_storage_pth: str = "sglang_storage" enable_cache_report: bool = False # Data parallelism dp_size: int = 1 load_balance_method: str = "round_robin" + # Expert parallelism ep_size: int = 1 @@ -105,14 +110,6 @@ class ServerArgs: # Model override args in JSON json_model_override_args: str = "{}" - # Double Sparsity - enable_double_sparsity: bool = False - ds_channel_config_path: str = None - ds_heavy_channel_num: int = 32 - ds_heavy_token_num: int = 256 - ds_heavy_channel_type: str = "qk" - ds_sparse_decode_threshold: int = 4096 - # LoRA lora_paths: Optional[List[str]] = None max_loras_per_batch: int = 8 @@ -122,6 +119,21 @@ class ServerArgs: sampling_backend: Optional[str] = None grammar_backend: Optional[str] = "outlines" + # Speculative decoding + speculative_draft_model_path: Optional[str] = None + speculative_algorithm: Optional[str] = None + speculative_num_steps: int = 5 + speculative_num_draft_tokens: int = 64 + speculative_eagle_topk: int = 8 + + # Double Sparsity + enable_double_sparsity: bool = False + ds_channel_config_path: str = None + ds_heavy_channel_num: int = 32 + ds_heavy_token_num: int = 256 + ds_heavy_channel_type: str = "qk" + ds_sparse_decode_threshold: int = 4096 + # Optimization/debug options disable_radix_cache: bool = False disable_jump_forward: bool = False @@ -137,12 +149,21 @@ class ServerArgs: enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None + cuda_graph_bs: Optional[List[int]] = None torchao_config: str = "" enable_nan_detection: bool = False enable_p2p_check: bool = False triton_attention_reduce_in_fp32: bool = False + triton_attention_num_kv_splits: int = 8 num_continuous_decode_steps: int = 1 delete_ckpt_after_loading: bool = False + enable_memory_saver: bool = False + allow_auto_truncate: bool = False + + # Custom logit processor + 
enable_custom_logit_processor: bool = False + tool_call_parser: str = None + enable_hierarchical_cache: bool = False def __post_init__(self): # Set missing default values @@ -216,25 +237,34 @@ def __post_init__(self): ) self.disable_cuda_graph = True + # Expert parallelism + if self.enable_ep_moe: + self.ep_size = self.tp_size + logger.info( + f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." + ) + # Others if self.enable_dp_attention: self.dp_size = self.tp_size + assert self.tp_size % self.dp_size == 0 self.chunked_prefill_size = self.chunked_prefill_size // 2 - self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96) self.schedule_conservativeness = self.schedule_conservativeness * 0.3 - self.disable_overlap_schedule = True logger.warning( f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. " - f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. " f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. " "Data parallel size is adjusted to be the same as tensor parallel size. " - "Overlap scheduler is disabled." ) - # Expert parallelism - if self.enable_ep_moe: - self.ep_size = self.tp_size + + # Speculative Decoding + if self.speculative_algorithm == "EAGLE": + self.prefill_only_one_req = True + self.disable_cuda_graph_padding = True + self.disable_radix_cache = True + self.disable_overlap_schedule = True + self.chunked_prefill_size = -1 logger.info( - f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." + "The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding." ) # GGUF @@ -282,7 +312,16 @@ def add_cli_args(parser: argparse.ArgumentParser): "--load-format", type=str, default=ServerArgs.load_format, - choices=["auto", "pt", "safetensors", "npcache", "dummy", "gguf"], + choices=[ + "auto", + "pt", + "safetensors", + "npcache", + "dummy", + "gguf", + "bitsandbytes", + "layered", + ], help="The format of the model weights to load. " '"auto" will try to load the weights in the safetensors format ' "and fall back to the pytorch bin format if safetensors format " @@ -293,7 +332,12 @@ def add_cli_args(parser: argparse.ArgumentParser): "a numpy cache to speed up the loading. " '"dummy" will initialize the weights with random values, ' "which is mainly for profiling." - '"gguf" will load the weights in the gguf format. ', + '"gguf" will load the weights in the gguf format. ' + '"bitsandbytes" will load the weights using bitsandbytes ' + "quantization." + '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -318,8 +362,17 @@ def add_cli_args(parser: argparse.ArgumentParser): "--kv-cache-dtype", type=str, default=ServerArgs.kv_cache_dtype, - choices=["auto", "fp8_e5m2"], - help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.', + choices=["auto", "fp8_e5m2", "fp8_e4m3"], + help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.', + ) + parser.add_argument( + "--quantization-param-path", + type=nullable_str, + default=None, + help="Path to the JSON file containing the KV cache " + "scaling factors. 
This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--quantization", @@ -334,6 +387,8 @@ def add_cli_args(parser: argparse.ArgumentParser): "awq_marlin", "bitsandbytes", "gguf", + "modelopt", + "w8a8_int8", ], help="The quantization method.", ) @@ -347,7 +402,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--device", type=str, default="cuda", - choices=["cuda", "xpu", "hpu"], + choices=["cuda", "xpu", "hpu", "cpu"], help="The device type.", ) parser.add_argument( @@ -375,7 +430,6 @@ def add_cli_args(parser: argparse.ArgumentParser): "name, a tag name, or a commit id. If unspecified, will use " "the default version.", ) - # Memory and scheduling parser.add_argument( "--mem-fraction-static", @@ -421,13 +475,18 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.schedule_conservativeness, help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.", ) - parser.add_argument( "--cpu-offload-gb", type=int, default=ServerArgs.cpu_offload_gb, help="How many GBs of RAM to reserve for CPU offloading", ) + parser.add_argument( + "--prefill-only-one-req", + type=bool, + help="If true, we only prefill one request at one prefill batch", + default=ServerArgs.prefill_only_one_req, + ) # Other runtime options parser.add_argument( @@ -443,6 +502,11 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.stream_interval, help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher", ) + parser.add_argument( + "--stream-output", + action="store_true", + help="Whether to output as a sequence of disjoint segments.", + ) parser.add_argument( "--random-seed", type=int, @@ -506,7 +570,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--decode-log-interval", type=int, default=ServerArgs.decode_log_interval, - help="The log interval of decode batch", + help="The log interval of decode batch.", ) # API related @@ -546,6 +610,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "shortest_queue", ], ) + # Expert parallelism parser.add_argument( "--expert-parallel-size", @@ -577,43 +642,6 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.json_model_override_args, ) - # Double Sparsity - parser.add_argument( - "--enable-double-sparsity", - action="store_true", - help="Enable double sparsity attention", - ) - parser.add_argument( - "--ds-channel-config-path", - type=str, - default=ServerArgs.ds_channel_config_path, - help="The path of the double sparsity channel config", - ) - parser.add_argument( - "--ds-heavy-channel-num", - type=int, - default=ServerArgs.ds_heavy_channel_num, - help="The number of heavy channels in double sparsity attention", - ) - parser.add_argument( - "--ds-heavy-token-num", - type=int, - default=ServerArgs.ds_heavy_token_num, - help="The number of heavy tokens in double sparsity attention", - ) - parser.add_argument( - "--ds-heavy-channel-type", - type=str, - default=ServerArgs.ds_heavy_channel_type, - help="The type of heavy channels in double sparsity attention", - ) - parser.add_argument( - "--ds-sparse-decode-threshold", - type=int, - default=ServerArgs.ds_sparse_decode_threshold, - help="The type of heavy channels in double sparsity attention", - ) - # LoRA 
parser.add_argument( "--lora-paths", @@ -653,6 +681,75 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Choose the backend for grammar-guided decoding.", ) + # Speculative decoding + parser.add_argument( + "--speculative-algorithm", + type=str, + choices=["EAGLE"], + help="Speculative algorithm.", + ) + parser.add_argument( + "--speculative-draft-model-path", + type=str, + help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.", + ) + parser.add_argument( + "--speculative-num-steps", + type=int, + help="The number of steps sampled from draft model in Speculative Decoding.", + default=ServerArgs.speculative_num_steps, + ) + parser.add_argument( + "--speculative-num-draft-tokens", + type=int, + help="The number of token sampled from draft model in Speculative Decoding.", + default=ServerArgs.speculative_num_draft_tokens, + ) + parser.add_argument( + "--speculative-eagle-topk", + type=int, + help="The number of token sampled from draft model in eagle2 each step.", + choices=[1, 2, 4, 8], + default=ServerArgs.speculative_eagle_topk, + ) + + # Double Sparsity + parser.add_argument( + "--enable-double-sparsity", + action="store_true", + help="Enable double sparsity attention", + ) + parser.add_argument( + "--ds-channel-config-path", + type=str, + default=ServerArgs.ds_channel_config_path, + help="The path of the double sparsity channel config", + ) + parser.add_argument( + "--ds-heavy-channel-num", + type=int, + default=ServerArgs.ds_heavy_channel_num, + help="The number of heavy channels in double sparsity attention", + ) + parser.add_argument( + "--ds-heavy-token-num", + type=int, + default=ServerArgs.ds_heavy_token_num, + help="The number of heavy tokens in double sparsity attention", + ) + parser.add_argument( + "--ds-heavy-channel-type", + type=str, + default=ServerArgs.ds_heavy_channel_type, + help="The type of heavy channels in double sparsity attention", + ) + parser.add_argument( + "--ds-sparse-decode-threshold", + type=int, + default=ServerArgs.ds_sparse_decode_threshold, + help="The type of heavy channels in double sparsity attention", + ) + # Optimization/debug options parser.add_argument( "--disable-radix-cache", @@ -689,11 +786,6 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.", ) - parser.add_argument( - "--disable-nan-detection", - action="store_true", - help="Disable the NaN detection for better performance.", - ) parser.add_argument( "--disable-overlap-schedule", action="store_true", @@ -731,6 +823,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=ServerArgs.cuda_graph_max_bs, help="Set the maximum batch size for cuda graph.", ) + parser.add_argument( + "--cuda-graph-bs", + type=int, + nargs="+", + help="Set the list of batch sizes for cuda graph.", + ) parser.add_argument( "--torchao-config", type=str, @@ -753,6 +851,12 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." "This only affects Triton attention kernels.", ) + parser.add_argument( + "--triton-attention-num-kv-splits", + type=int, + default=ServerArgs.triton_attention_num_kv_splits, + help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. 
The default value is 8.", + ) parser.add_argument( "--num-continuous-decode-steps", type=int, @@ -766,27 +870,33 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Delete the model checkpoint after loading the model.", ) - - # Deprecated arguments parser.add_argument( - "--enable-overlap-schedule", - action=DeprecatedAction, - help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.", + "--enable-memory-saver", + action="store_true", + help="Allow saving memory using release_memory_occupation and resume_memory_occupation", ) parser.add_argument( - "--disable-flashinfer", - action=DeprecatedAction, - help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.", + "--allow-auto-truncate", + action="store_true", + help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.", ) parser.add_argument( - "--disable-flashinfer-sampling", - action=DeprecatedAction, - help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.", + "--enable-custom-logit-processor", + action="store_true", + help="Enable users to pass custom logit processors to the server (disabled by default for security)", ) + # Function Calling parser.add_argument( - "--disable-disk-cache", - action=DeprecatedAction, - help="'--disable-disk-cache' is deprecated. Please use '--disable-outlines-disk-cache' instead.", + "--tool-call-parser", + type=str, + choices=["qwen25", "mistral", "llama3"], + default=ServerArgs.tool_call_parser, + help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.", + ) + parser.add_argument( + "--enable-hierarchical-cache", + action="store_true", + help="Enable hierarchical cache", ) @classmethod @@ -798,7 +908,7 @@ def from_cli_args(cls, args: argparse.Namespace): return cls(**{attr: getattr(args, attr) for attr in attrs}) def url(self): - if is_ipv6(self.host): + if is_valid_ipv6_address(self.host): return f"http://[{self.host}]:{self.port}" else: return f"http://{self.host}:{self.port}" @@ -808,8 +918,8 @@ def check_server_args(self): self.tp_size % self.nnodes == 0 ), "tp_size must be divisible by number of nodes" assert not ( - self.dp_size > 1 and self.nnodes != 1 - ), "multi-node data parallel is not supported" + self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention + ), "multi-node data parallel is not supported unless dp attention!" 
assert ( self.max_loras_per_batch > 0 # FIXME @@ -847,6 +957,9 @@ def prepare_server_args(argv: List[str]) -> ServerArgs: return server_args +ZMQ_TCP_PORT_DELTA = 233 + + @dataclasses.dataclass class PortArgs: # The ipc filename for tokenizer to receive inputs from detokenizer (zmq) @@ -860,19 +973,49 @@ class PortArgs: nccl_port: int @staticmethod - def init_new(server_args) -> "PortArgs": + def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs": port = server_args.port + random.randint(100, 1000) while True: if is_port_available(port): break - port += 42 + if port < 60000: + port += 42 + else: + port -= 43 - return PortArgs( - tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, - scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, - detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, - nccl_port=port, - ) + if not server_args.enable_dp_attention: + # Normal case, use IPC within a single node + return PortArgs( + tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", + scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", + detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", + nccl_port=port, + ) + else: + # DP attention. Use TCP + port to handle both single-node and multi-node. + if server_args.nnodes == 1 and server_args.dist_init_addr is None: + dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA) + else: + dist_init_addr = server_args.dist_init_addr.split(":") + assert ( + len(dist_init_addr) == 2 + ), "please provide --dist-init-addr as host:port of head node" + + dist_init_host, dist_init_port = dist_init_addr + port_base = int(dist_init_port) + 1 + if dp_rank is None: + scheduler_input_port = ( + port_base + 2 + ) # TokenizerManager to DataParallelController + else: + scheduler_input_port = port_base + 2 + 1 + dp_rank + + return PortArgs( + tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}", + scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}", + detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}", + nccl_port=port, + ) class LoRAPathAction(argparse.Action): diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py new file mode 100644 index 00000000000..6412825ed8c --- /dev/null +++ b/python/sglang/srt/speculative/build_eagle_tree.py @@ -0,0 +1,347 @@ +import cutex +import torch + +# parent_table [bs,topk*depth+)] +# selected_index [bs,draft_token_num-1)] +# verified_seq_len [bs] +# tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] 
= [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] +# positions [bs*draft_token] +# retrive_index [b, draft_token, depth+2] +kernels = cutex.SourceModule( + """ +//cuda +__global__ void build_tree(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, + Tensor tree_mask, Tensor positions, Tensor retrive_index, int topk, int depth, int draft_token_num) { + int bid = blockIdx.x; + int tid = threadIdx.x; + if (tid >= draft_token_num){ + return; + } + int seq_tree_idx = draft_token_num * draft_token_num * bid; + for(int i=0; i 1: + index = tl.arange(0, draft_token_num) + mask_left = index != max_index + remained_index = tl.where(mask_max and mask_left, index, 0) + max_index = tl.max(remained_index) + + tl.store(accept_length + pid, accept_len) + retrive_index_ptr = retrive_index + (retrive_start + max_index) * max_len + retrive_offset = tl.arange(0, max_len_upper) + retrive_load_mask = retrive_offset < accept_len + 1 + data = tl.load(retrive_index_ptr + retrive_offset, mask=retrive_load_mask) + + tl.store( + accept_index + pid * max_len + retrive_offset, data, mask=retrive_load_mask + ) + + extract_load_ptr = accept_index + pid * max_len + accept_len + if accept_len == max_len - 1: + extract_data = tl.load(extract_load_ptr - 1) + tl.store(extract_index + pid * 2, extract_data) + extract_data = tl.load(extract_load_ptr) + tl.store(extract_index + pid * 2 + 1, extract_data) + + else: + extract_data = tl.load(extract_load_ptr) + tl.store(extract_index + pid * 2, extract_data) + + +@triton.jit +def create_extend_spec_info( + verified_id, + seq_len, + accept_len, + accept_len_cum, + positions, + new_verified_id, + accept_len_upper: tl.constexpr, +): + pid = tl.program_id(axis=0) + offset = 0 if pid == 0 else tl.load(accept_len_cum + pid - 1) + seq_length = tl.load(seq_len + pid) + accept_length = tl.load(accept_len + pid) + positions_ptr = positions + offset + data = tl.arange(0, accept_len_upper) + mask = data < accept_length + tl.store(positions_ptr + data, seq_length - accept_length + data, mask) + + offset = tl.load(accept_len_cum + pid) - 1 + verified_id_data = tl.load(verified_id + offset) + tl.store(new_verified_id + pid, verified_id_data) + + +@triton.jit +def assign_req_to_token_pool( + req_pool_indices, + req_to_token, + start_offset, + end_offset, + out_cache_loc, + pool_len: tl.constexpr, + bs_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 32 + pid = tl.program_id(axis=0) + kv_start = tl.load(start_offset + pid) + kv_end = tl.load(end_offset + pid) + token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len + + length_offset = tl.arange(0, bs_upper) + start = tl.load(start_offset + length_offset, mask=length_offset < pid) + end = tl.load(end_offset + length_offset, mask=length_offset < pid) + out_offset = tl.sum(end - start, axis=0) + + out_cache_ptr = out_cache_loc + out_offset + + save_offset = tl.arange(0, BLOCK_SIZE) + kv_start + load_offset = tl.arange(0, BLOCK_SIZE) + + num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + for _ in range(num_loop): + mask = save_offset < kv_end + data = tl.load(out_cache_ptr + load_offset, mask=mask) + tl.store(token_pool + save_offset, data, mask=mask) + save_offset += BLOCK_SIZE + load_offset += BLOCK_SIZE + + +@triton.jit +def generate_draft_decode_kv_indices( + req_pool_indices, + req_to_token, + paged_kernel_lens, + kv_indices, + iters: tl.constexpr, + topk: tl.constexpr, + pool_len: tl.constexpr, + bs_upper: tl.constexpr, + iter_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 128 + bid = 
tl.program_id(axis=0) + topk_id = tl.program_id(axis=1) + + load_offset = tl.arange(0, bs_upper) + seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid) + seq_len = tl.load(paged_kernel_lens + bid) + cum_seq_len = tl.sum(seq_lens) + + kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters) + kv_ptr = kv_indices + kv_offset + token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len + + kv_offset = tl.arange(0, BLOCK_SIZE) + num_loop = tl.cdiv(seq_len, BLOCK_SIZE) + for _ in range(num_loop): + mask = kv_offset < seq_len + data = tl.load(token_pool_ptr + kv_offset, mask=mask) + tl.store(kv_ptr + kv_offset, data, mask=mask) + kv_offset += BLOCK_SIZE + + extend_offset = tl.arange(0, iter_upper) + extend_data = tl.load( + token_pool_ptr + seq_len + tl.arange(0, iter_upper) * topk + topk_id, + mask=extend_offset < iters, + ) + tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters) + + +class EAGLEDraftInput(SpecInfo): + def __init__(self): + self.prev_mode = ForwardMode.DECODE + + self.scores: torch.Tensor = None + self.score_list: List[torch.Tensor] = [] + self.token_list: List[torch.Tensor] = [] + self.origin_score_list: List[torch.Tensor] = [] # used for sampling + self.parents_list: List[torch.Tensor] = [] + self.cache_list: List[torch.Tenor] = [] + self.iter = 0 + + # shape: (b, hidden_size) + self.hidden_states: torch.Tensor = None + # shape: (b,) + self.verified_id: torch.Tensor = None + # shape: (b, vocab_size) + self.sample_output: torch.Tensor = None + + self.positions: torch.Tensor = None + self.accept_length: torch.Tensor = None + self.accept_length_cpu: List[int] = None + + def load_server_args(self, server_args: ServerArgs): + self.topk: int = server_args.speculative_eagle_topk + self.num_verify_token: int = server_args.speculative_num_draft_tokens + self.spec_steps = server_args.speculative_num_steps + + def prepare_for_extend(self, batch: ScheduleBatch): + req_pool_indices = batch.alloc_req_slots(len(batch.reqs)) + out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + batch.out_cache_loc = out_cache_loc + + pt = 0 + for i, req in enumerate(batch.reqs): + req.req_pool_idx = req_pool_indices[i] + pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids) + assert seq_len - pre_len == req.extend_input_len + + if pre_len > 0: + batch.req_to_token_pool.req_to_token[req.req_pool_idx][ + :pre_len + ] = req.prefix_indices + + batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = ( + out_cache_loc[pt : pt + req.extend_input_len] + ) + + pt += req.extend_input_len + + # TODO: support batching inputs + assert len(batch.extend_lens) == 1 + batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id)) + + def filter_batch( + self, + new_indices: torch.Tensor, + ): + self.sample_output = self.sample_output[: len(new_indices)] + self.hidden_states = self.hidden_states[: len(new_indices)] + self.verified_id = self.verified_id[: len(new_indices)] + + def prepare_for_decode(self, batch: ScheduleBatch): + prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab) + top = torch.topk(prob, self.topk, dim=-1) + topk_index, topk_p = ( + top.indices, + top.values, + ) # shape: (b * top_k, top_k) or (b, top_k) + + if self.prev_mode.is_decode(): + scores = torch.mul( + self.scores.unsqueeze(2), topk_p.reshape(-1, self.topk, self.topk) + ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk) + topk_cs = torch.topk( + scores.flatten(start_dim=1), self.topk, 
dim=-1 + ) # (b, topk) + topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values + + selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange( + 0, batch.batch_size() * self.topk, step=self.topk, device="cuda" + ).repeat_interleave(self.topk) + + batch.spec_info.hidden_states = batch.spec_info.hidden_states[ + selected_input_index, : + ] + + topk_index = topk_index.reshape(-1, self.topk**2) + batch.input_ids = torch.gather( + topk_index, index=topk_cs_index, dim=1 + ).flatten() + batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids)) + + self.scores = topk_cs_p + self.score_list.append(scores) # (b, topk, topk) + self.token_list.append(topk_index) # (b, topk * topk) + self.origin_score_list.append(topk_p.reshape(topk_index.shape)) + self.parents_list.append( + topk_cs_index + (self.topk**2 * (self.iter - 1) + self.topk) + ) # shape: (b, topk) + else: + # ForwardMode.EXTEND or ForwardMode.DRAFT_EXTEND + batch.spec_info.hidden_states = ( + batch.spec_info.hidden_states.repeat_interleave(self.topk, dim=0) + ) + + batch.input_ids = topk_index.flatten() + batch.out_cache_loc = batch.alloc_token_slots(topk_index.numel()) + + self.scores = topk_p # shape: (b, topk) + self.score_list.append(topk_p.unsqueeze(1)) # shape: (b, 1, topk) + self.token_list.append(topk_index) # shape: (b, topk) + self.origin_score_list.append(topk_p) + self.parents_list.append( + torch.arange(-1, self.topk, dtype=torch.long, device="cuda") + .unsqueeze(0) + .repeat(self.scores.shape[0], 1) + ) # shape: (b, topk + 1) + self.cache_list.append(batch.out_cache_loc) + self.positions = ( + batch.seq_lens[:, None] + + torch.full( + [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long + ) + ).flatten() + + bs = len(batch.seq_lens) + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens + self.topk * self.iter, + batch.seq_lens + self.topk * (self.iter + 1), + batch.out_cache_loc, + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs), + ) + self.iter += 1 + + def prepare_extend_after_decode(self, batch: ScheduleBatch): + batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel()) + accept_length_cpu = batch.spec_info.accept_length_cpu + batch.extend_lens = [x + 1 for x in accept_length_cpu] + batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend + seq_lens_cpu = batch.seq_lens.tolist() + + pt = 0 + i = 0 + for req in batch.reqs: + if req.finished(): + continue + # assert seq_len - pre_len == req.extend_input_len + input_len = batch.extend_lens[i] + seq_len = seq_lens_cpu[i] + batch.req_to_token_pool.req_to_token[req.req_pool_idx][ + seq_len - input_len : seq_len + ] = batch.out_cache_loc[pt : pt + input_len] + pt += input_len + i += 1 + assert pt == batch.out_cache_loc.shape[0] + + self.positions = torch.empty_like(self.verified_id) + new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long) + self.accept_length.add_(1) + + create_extend_spec_info[(self.accept_length.numel(),)]( + self.verified_id, + batch.seq_lens, + self.accept_length, + torch.cumsum(self.accept_length, axis=0, dtype=torch.int), + self.positions, + new_verified_id, + triton.next_power_of_2(self.spec_steps + 1), + ) + + batch.seq_lens_sum = sum(seq_lens_cpu) + batch.input_ids = self.verified_id + self.verified_id = new_verified_id + + def prepare_for_verify(self, batch: ScheduleBatch): + score_list = torch.cat(self.score_list, dim=1).flatten( + 1 + ) # b, n, topk; n= 1+(self.iter-1)*self.topk + 
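# Toy illustration of the selection below (b = 1, num_verify_token = 4):
#   score_list      = [[0.50, 0.20, 0.90, 0.10, 0.30]]  # one cumulative score per tree candidate
#   topk(3).indices -> [2, 0, 4]; sorted -> [0, 2, 4]    # sorting keeps the original tree order
#   draft_tokens    = gather(ss_token_list, [0, 2, 4])   # token ids of the kept candidates;
#                                                        # the verified root token is prepended next.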
ss_token_list = torch.cat( + self.token_list, dim=1 + ) # b, (self.topk+(self.iter-1)*self.topk) + origin_token_list = torch.cat(self.origin_score_list, dim=1) + top_scores = torch.topk(score_list, self.num_verify_token - 1, dim=-1) + top_scores_index = top_scores.indices + top_scores_index = torch.sort(top_scores_index).values + + draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1) + scores = torch.gather(origin_token_list, index=top_scores_index, dim=1) + draft_tokens = torch.cat((self.verified_id.unsqueeze(1), draft_tokens), dim=1) + parent_list = torch.cat(self.parents_list[:-1], dim=1) + + tree_mask, position, retrive_index, retrive_cum_len = build_tree_kernel( + parent_list, + top_scores_index, + batch.seq_lens, + self.topk, + self.iter - 1, + self.num_verify_token, + ) + + return EagleVerifyInput( + draft_tokens.flatten(), + scores.flatten(), + tree_mask, + position, + retrive_index, + retrive_cum_len, + self.num_verify_token, + ) + + def generate_attn_arg_decode( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token: torch.Tensor, + ): + seq_num = req_pool_indices.numel() + bs = self.topk * req_pool_indices.numel() + seq_len = self.positions.reshape(-1).contiguous() + + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") + cum_kv_seq_len[1:] = torch.cumsum(seq_len + 1, dim=0) + total_len = torch.sum(paged_kernel_lens).item() + + kv_indices = torch.empty( + (total_len * self.topk + seq_num * self.iter * self.topk,), + dtype=torch.int32, + device="cuda", + ) + + generate_draft_decode_kv_indices[(req_pool_indices.numel(), self.topk)]( + req_pool_indices, + req_to_token, + paged_kernel_lens, + kv_indices, + self.iter, + self.topk, + req_to_token.shape[1], + triton.next_power_of_2(seq_num), + triton.next_power_of_2(self.spec_steps), + ) + return bs, kv_indices, cum_kv_seq_len + + def clear_draft_cache(self, batch): + draft_cache = torch.cat(self.cache_list, dim=0) + batch.token_to_kv_pool.free(draft_cache) + + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token: torch.Tensor, + ): + bs = self.accept_length.numel() + qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") + qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0) + + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") + + create_flashinfer_kv_indices_triton[(bs,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + + return kv_indices, cum_kv_seq_len, qo_indptr, None + + def merge_batch(self, spec_info: EAGLEDraftInput): + if self.hidden_states is None: + self.hidden_states = spec_info.hidden_states + self.verified_id = spec_info.verified_id + self.sample_output = spec_info.sample_output + self.prev_mode = spec_info.prev_mode + return + if spec_info.hidden_states is None: + return + self.hidden_states = torch.cat( + [self.hidden_states, spec_info.hidden_states], axis=0 + ) + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) + self.sample_output = torch.cat([self.sample_output, spec_info.sample_output]) + + +class EagleVerifyInput(SpecInfo): + def __init__( + self, + draft_token: torch.Tensor, + draft_score: torch.Tensor, + tree_mask: torch.Tensor, + positions: torch.Tensor, + retrive_index: 
torch.Tensor, + retrive_cum_len: torch.Tensor, + draft_token_num: int, + ): + self.draft_token = draft_token + self.draft_score = draft_score + self.custom_mask = tree_mask + self.positions = positions + self.retrive_index = retrive_index + self.retrive_cum_len = retrive_cum_len + self.draft_token_num = draft_token_num + + def prepare_for_verify(self, batch: ScheduleBatch): + batch.input_ids = self.draft_token + batch.out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel()) + bs = batch.seq_lens.numel() + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + self.draft_token_num, + batch.out_cache_loc, + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs), + ) + + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + req_to_token: torch.Tensor, + ): + batch_size = len(req_pool_indices) + qo_indptr = torch.arange( + 0, + (1 + batch_size) * self.draft_token_num, + step=self.draft_token_num, + dtype=torch.int32, + device="cuda", + ) + + cum_kv_seq_len = torch.zeros( + (batch_size + 1,), dtype=torch.int32, device="cuda" + ) + + paged_kernel_lens = paged_kernel_lens + self.draft_token_num + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda") + + create_flashinfer_kv_indices_triton[(batch_size,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask + + def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor: + predict = torch.argmax(logits_output.next_token_logits, dim=-1) + predict = torch.cat( + [predict, torch.full([1], -1, dtype=torch.long, device="cuda")], dim=-1 + ) + draft_token = torch.cat( + [self.draft_token, torch.full([1], -1, dtype=torch.long, device="cuda")], + dim=-1, + ) + target_predict = predict[self.retrive_index] + candidates = draft_token[self.retrive_index] + # logits = logits_output.next_token_logits[self.retrive_index] + # target_predict = torch.argmax(logits[:, :-1], dim=-1) + accept_mask = candidates[:, 1:] == target_predict[:, :-1] + accept_mask = (torch.cumprod(accept_mask, dim=1)).sum(dim=1) + bs = self.retrive_cum_len.numel() - 1 + + max_draft_len = self.retrive_index.shape[-1] + accept_index = torch.full( + (bs, max_draft_len), -1, dtype=torch.long, device="cuda" + ) + accept_length = torch.empty((bs,), dtype=torch.int, device="cuda") + extract_index = torch.full((bs * 2,), 0, dtype=torch.int, device="cuda") + eagle_verify_retrive[(bs,)]( + self.retrive_index.contiguous(), + accept_mask.contiguous(), + self.retrive_cum_len, + accept_index, + accept_length, + extract_index, + max_draft_len, + self.draft_token_num, + triton.next_power_of_2(max_draft_len), + ) + + draft_input = EAGLEDraftInput() + new_accept_index = [] + unfinished_index = [] + finished_extend_len = {} # {rid:accept_length + 1} + accept_index_cpu = accept_index.tolist() + predict_cpu = predict.tolist() + has_finished = False + + # iterate every accepted token and check if req has finished after append the token + # should be checked BEFORE free kv cache slots + for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)): + new_accept_index_ = [] + for j, idx in enumerate(accept_index_row): + if idx == -1: + break + id = predict_cpu[idx] + # if not found_finished: 
+ req.output_ids.append(id) + finished_extend_len[req.rid] = j + 1 + req.check_finished() + if req.finished(): + has_finished = True + # set all tokens after finished token to -1 and break + accept_index[i, j + 1 :] = -1 + break + else: + new_accept_index_.append(idx) + if not req.finished(): + new_accept_index.extend(new_accept_index_) + unfinished_index.append(i) + req.spec_verify_ct += 1 + accept_length = (accept_index != -1).sum(dim=1) - 1 + + accept_index = accept_index[accept_index != -1] + accept_length_cpu = accept_length.tolist() + verified_id = predict[accept_index] + + evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) + evict_mask[accept_index] = False + mem_need_free_idx = batch.out_cache_loc[evict_mask] + batch.token_to_kv_pool.free(mem_need_free_idx) + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + accept_length + 1, + batch.out_cache_loc[accept_index], + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs), + ) + batch.seq_lens.add_(accept_length + 1) + + if len(new_accept_index) > 0: + new_accept_index = torch.tensor(new_accept_index, device="cuda") + draft_input.verified_id = predict[new_accept_index] + draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index] + draft_input.accept_length = accept_length[unfinished_index] + draft_input.accept_length_cpu = [ + accept_length_cpu[i] for i in unfinished_index + ] + if has_finished: + draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index] + else: + draft_input.seq_lens_for_draft_extend = batch.seq_lens + + logits_output.next_token_logits = logits_output.next_token_logits[accept_index] + return ( + draft_input, + logits_output, + verified_id, + finished_extend_len, + accept_length_cpu, + ) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py new file mode 100644 index 00000000000..06a4372fce2 --- /dev/null +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -0,0 +1,183 @@ +from typing import List, Optional, Union + +import torch + +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.eagle_utils import EAGLEDraftInput +from sglang.srt.utils import rank0_print + + +class EAGLEWorker(TpModelWorker): + + def __init__( + self, + server_args: ServerArgs, + gpu_id: int, + tp_rank: int, + dp_rank: Optional[int], + nccl_port: int, + target_worker: TpModelWorker, + ): + # Do not capture cuda graph in `super().__init__()` + # We will capture it later + backup_disable_cuda_graph = server_args.disable_cuda_graph + server_args.disable_cuda_graph = True + super().__init__( + gpu_id=gpu_id, + tp_rank=tp_rank, + server_args=server_args, + nccl_port=nccl_port, + dp_rank=dp_rank, + is_draft_worker=True, + ) + self.target_worker = target_worker + self.server_args = server_args + self.finish_extend_len = [] + + # Share the embedding and lm_head + embed, head = self.target_worker.model_runner.model.get_embed_and_head() + self.model_runner.model.set_embed_and_head(embed, head) + self.model_runner.server_args.disable_cuda_graph = 
backup_disable_cuda_graph + self.model_runner.init_cuda_graphs() + + def forward_draft_decode(self, batch: ScheduleBatch): + batch.spec_info.prepare_for_decode(batch) + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + logits_output = self.model_runner.forward(forward_batch) + self.capture_for_decode(logits_output, forward_batch) + + def forward_draft_extend(self, batch: ScheduleBatch): + self._set_mem_pool(batch, self.model_runner) + batch.spec_info.prepare_for_extend(batch) + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + logits_output = self.model_runner.forward(forward_batch) + self.capture_for_decode(logits_output, forward_batch) + self._set_mem_pool(batch, self.target_worker.model_runner) + + def forward_batch_speculative_generation(self, batch: ScheduleBatch): + if batch.forward_mode.is_decode(): + # Draft + self._set_mem_pool(batch, self.model_runner) + for i in range(self.server_args.speculative_num_steps): + self.forward_draft_decode(batch) + batch.spec_info.clear_draft_cache(batch) + self._set_mem_pool(batch, self.target_worker.model_runner) + + # Verify + ( + next_draft_input, + logits_output, + verified_id, + self.finish_extend_len, + accept_length_cpu, + model_worker_batch, + ) = self.verify(batch) + next_draft_input.load_server_args(self.server_args) + batch.spec_info = next_draft_input + # if it is None, means all requsets are finished + if batch.spec_info.verified_id is not None: + self.forward_draft_extend_after_decode(batch) + return ( + logits_output, + verified_id, + model_worker_batch, + sum(accept_length_cpu), + ) + + else: + # Forward with the target model and get hidden states. + # We need the full hidden states to prefill the KV cache of the draft model. + model_worker_batch = batch.get_model_worker_batch() + model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL + logits_output, next_token_ids = self.target_worker.forward_batch_generation( + model_worker_batch + ) + + # Forward with the draft model. 
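# In summary, forward_batch_speculative_generation has two paths:
#   * decode: run `speculative_num_steps` draft forwards to grow a token tree, then one
#     target forward verifies the tree and accepts a prefix of it (see verify() below).
#   * extend (this branch): the target model prefills and captures its full hidden states,
#     which seed the draft model's own prefill right here.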
+ spec_info = EAGLEDraftInput() + spec_info.load_server_args(self.server_args) + spec_info.hidden_states = logits_output.hidden_states + spec_info.verified_id = next_token_ids + batch.spec_info = spec_info + self.forward_draft_extend(batch) + return logits_output, next_token_ids, model_worker_batch, 0 + + def verify(self, batch: ScheduleBatch): + verify_input = batch.spec_info.prepare_for_verify(batch) + verify_input.prepare_for_verify(batch) + batch.forward_mode = ForwardMode.TARGET_VERIFY + batch.spec_info = verify_input + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.FULL + model_worker_batch = batch.get_model_worker_batch() + logits_output, _ = self.target_worker.forward_batch_generation( + model_worker_batch, skip_sample=True + ) + verify_input.hidden_states = logits_output.hidden_states + res = verify_input.verify(batch, logits_output) + batch.forward_mode = ForwardMode.DECODE + return res + (model_worker_batch,) + + def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner): + batch.token_to_kv_pool = runner.token_to_kv_pool + batch.req_to_token_pool = runner.req_to_token_pool + + def forward_draft_extend_after_decode(self, batch: ScheduleBatch): + seq_lens_backup = batch.seq_lens + + self._set_mem_pool(batch, self.model_runner) + batch.forward_mode = ForwardMode.DRAFT_EXTEND + batch.spec_info.prepare_extend_after_decode(batch) + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + model_worker_batch = batch.get_model_worker_batch() + forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) + logits_output = self.model_runner.forward(forward_batch) + self.capture_for_decode(logits_output, forward_batch) + self._set_mem_pool(batch, self.target_worker.model_runner) + + # Restore backup. + # This is because `seq_lens` can be modified in `prepare_extend_after_decode` + batch.forward_mode = ForwardMode.DECODE + batch.seq_lens = seq_lens_backup + + def capture_for_decode( + self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch + ): + sample_output = torch.softmax( + logits_output.next_token_logits, dim=-1 + ) # TODO(kavioyu): Support more sampling methods + spec_info = forward_batch.spec_info + spec_info.sample_output = sample_output + spec_info.hidden_states = logits_output.hidden_states + spec_info.prev_mode = forward_batch.forward_mode + + # Don't support prefix share now. 
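EagleVerifyInput.verify (invoked from the worker's verify() above) reduces acceptance to an element-wise comparison followed by a cumulative product. A CPU-only toy with made-up token ids (cast to long only for this example):

import torch

# One row per tree path: column 0 is the already-verified token, columns 1.. are draft tokens.
candidates = torch.tensor([[7, 3, 9, 4]])
# Greedy prediction of the target model at each of those positions.
target_predict = torch.tensor([[3, 9, 5, 2]])

accept_mask = candidates[:, 1:] == target_predict[:, :-1]          # [[True, True, False]]
accept_len = torch.cumprod(accept_mask.long(), dim=1).sum(dim=1)   # tensor([2]): 2 draft tokens accepted
print(accept_len)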
+ def finish_request(self, reqs: Union[Req, List[Req]]): + if not isinstance(reqs, List): + reqs = [reqs] + for req in reqs: + if req.rid not in self.finish_extend_len: + continue + req_len = ( + len(req.origin_input_ids) + + len(req.output_ids) + - self.finish_extend_len[req.rid] + - 1 + ) + kv_indices = self.model_runner.req_to_token_pool.req_to_token[ + req.req_pool_idx + ][:req_len] + self.model_runner.token_to_kv_pool.free(kv_indices) + self.model_runner.req_to_token_pool.free(req.req_pool_idx) diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py new file mode 100644 index 00000000000..5f156b837f9 --- /dev/null +++ b/python/sglang/srt/speculative/spec_info.py @@ -0,0 +1,24 @@ +from enum import IntEnum, auto + + +class SpeculativeAlgorithm(IntEnum): + NONE = auto() + EAGLE = auto() + + def is_none(self): + return self == SpeculativeAlgorithm.NONE + + def is_eagle(self): + return self == SpeculativeAlgorithm.EAGLE + + @staticmethod + def from_string(name: str): + name_map = { + "EAGLE": SpeculativeAlgorithm.EAGLE, + None: SpeculativeAlgorithm.NONE, + } + return name_map[name] + + +class SpecInfo: + pass diff --git a/python/sglang/srt/torch_memory_saver_adapter.py b/python/sglang/srt/torch_memory_saver_adapter.py new file mode 100644 index 00000000000..31f8ebf2f07 --- /dev/null +++ b/python/sglang/srt/torch_memory_saver_adapter.py @@ -0,0 +1,59 @@ +from abc import ABC +from contextlib import contextmanager + +try: + import torch_memory_saver + + _primary_memory_saver = torch_memory_saver.TorchMemorySaver() +except ImportError: + pass + + +class TorchMemorySaverAdapter(ABC): + @staticmethod + def create(enable: bool): + return ( + _TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop() + ) + + def configure_subprocess(self): + raise NotImplementedError + + def region(self): + raise NotImplementedError + + def pause(self): + raise NotImplementedError + + def resume(self): + raise NotImplementedError + + +class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter): + def configure_subprocess(self): + return torch_memory_saver.configure_subprocess() + + def region(self): + return _primary_memory_saver.region() + + def pause(self): + return _primary_memory_saver.pause() + + def resume(self): + return _primary_memory_saver.resume() + + +class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter): + @contextmanager + def configure_subprocess(self): + yield + + @contextmanager + def region(self): + yield + + def pause(self): + pass + + def resume(self): + pass diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 5c310136a21..ebb346bbc63 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -14,6 +14,9 @@ """Common utilities.""" import base64 +import ctypes +import dataclasses +import io import ipaddress import itertools import json @@ -27,12 +30,14 @@ import signal import socket import subprocess +import sys import tempfile import time import warnings from functools import lru_cache from importlib.metadata import PackageNotFoundError, version from io import BytesIO +from multiprocessing.reduction import ForkingPickler from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union import numpy as np @@ -59,7 +64,6 @@ logger = logging.getLogger(__name__) - show_time_cost = False time_infos = {} @@ -70,7 +74,7 @@ def is_hip() -> bool: def is_cuda(): - return hasattr(torch, "cuda") and torch.cuda.is_available() + return hasattr(torch, "cuda") and torch.version.cuda is not 
None def is_cuda_alike(): @@ -92,15 +96,11 @@ def is_flashinfer_available(): """ if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"): return False - return torch.cuda.is_available() and not is_hip() + return torch.cuda.is_available() and torch.version.cuda -def is_ipv6(address): - try: - ipaddress.IPv6Address(address) - return True - except ipaddress.AddressValueError: - return False +def is_cuda_available(): + return torch.cuda.is_available() and torch.version.cuda def enable_show_time_cost(): @@ -169,7 +169,7 @@ def inner_func(*args, **kwargs): return wrapper -def get_available_gpu_memory(device, gpu_id, distributed=False): +def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True): """ Get available memory for cuda:gpu_id device. When distributed is True, the available memory is the minimum available memory of all GPUs. @@ -184,7 +184,8 @@ def get_available_gpu_memory(device, gpu_id, distributed=False): "which may cause useless memory allocation for torch CUDA context.", ) - torch.cuda.empty_cache() + if empty_cache: + torch.cuda.empty_cache() free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id) elif device == "xpu": @@ -196,7 +197,9 @@ def get_available_gpu_memory(device, gpu_id, distributed=False): f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ", "which may cause useless memory allocation for torch XPU context.", ) - torch.xpu.empty_cache() + + if empty_cache: + torch.xpu.empty_cache() used_memory = torch.xpu.memory_allocated() total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory free_gpu_memory = total_gpu_memory - used_memory @@ -213,6 +216,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False): free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info() + elif device == "cpu": + # TODO: rename the variables in the current function to be not GPU specific + free_gpu_memory = psutil.virtual_memory().available + if distributed: tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to( torch.device(device, gpu_id) @@ -330,6 +337,8 @@ def is_port_available(port): return True except socket.error: return False + except OverflowError: + return False def decode_video_base64(video_base64): @@ -500,76 +509,32 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N pass -def monkey_patch_vllm_p2p_access_check(gpu_id: int): +def monkey_patch_p2p_access_check(): """ - Monkey patch the slow p2p access check in vllm. + Monkey patch the slow p2p access check. NOTE: We assume the p2p access is always allowed, which can be wrong for some setups. 
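+    It replaces gpu_p2p_access_check with a stub that always returns True and disables
+    CustomAllreduce.__del__ so its warnings are not printed during shutdown.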
""" - import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt + import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True) # Suppress the warnings from this delete function when using sglang.bench_one_batch - from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce + from sglang.srt.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce, + ) setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None) -vllm_all_gather_backup = None - - -def monkey_patch_vllm_all_gather(reverse: bool = False): - """Monkey patch all-gather to remove in-place operations.""" - from torch.distributed import _functional_collectives as funcol - from vllm.distributed.parallel_state import GroupCoordinator - - global vllm_all_gather_backup - if vllm_all_gather_backup is None: - vllm_all_gather_backup = GroupCoordinator.all_gather - - def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: - world_size = self.world_size - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - assert ( - -input_.dim() <= dim < input_.dim() - ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - input_size = input_.size() - # Allocate output tensor. - output_tensor = torch.empty( - (world_size,) + input_size, dtype=input_.dtype, device=input_.device - ) - - output_tensor = funcol.all_gather_tensor( - input_, gather_dim=0, group=self.device_group - ).view((world_size,) + input_size) - - # Reshape - output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape( - input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] - ) - return output_tensor - - if reverse: - setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup) - else: - setattr(GroupCoordinator, "all_gather", all_gather) - - def monkey_patch_vllm_gguf_config(): - from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.gguf import ( GGUFConfig, GGUFEmbeddingMethod, GGUFLinearMethod, ) + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding def get_quant_method_with_embedding_replaced( @@ -704,13 +669,14 @@ def broadcast_pyobj( data: List[Any], rank: int, dist_group: Optional[torch.distributed.ProcessGroup] = None, + src: int = 0, ): """Broadcast inputs from rank=0 to all other ranks with torch.dist backend.""" if rank == 0: if len(data) == 0: tensor_size = torch.tensor([0], dtype=torch.long) - dist.broadcast(tensor_size, src=0, group=dist_group) + dist.broadcast(tensor_size, src=src, group=dist_group) else: serialized_data = pickle.dumps(data) size = len(serialized_data) @@ -719,19 +685,19 @@ def broadcast_pyobj( ) tensor_size = torch.tensor([size], dtype=torch.long) - dist.broadcast(tensor_size, src=0, group=dist_group) - dist.broadcast(tensor_data, src=0, group=dist_group) + dist.broadcast(tensor_size, src=src, group=dist_group) + dist.broadcast(tensor_data, src=src, group=dist_group) return data else: tensor_size = torch.tensor([0], dtype=torch.long) - dist.broadcast(tensor_size, src=0, group=dist_group) + dist.broadcast(tensor_size, src=src, group=dist_group) size = tensor_size.item() if size == 0: return [] tensor_data = torch.empty(size, dtype=torch.uint8) - dist.broadcast(tensor_data, src=0, 
group=dist_group) + dist.broadcast(tensor_data, src=src, group=dist_group) serialized_data = bytes(tensor_data.cpu().numpy()) data = pickle.loads(serialized_data) @@ -776,7 +742,9 @@ def first_rank_print(*args, **kwargs): pass -def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str): +def get_zmq_socket( + context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool +): mem = psutil.virtual_memory() total_mem = mem.total / 1024**3 available_mem = mem.available / 1024**3 @@ -789,19 +757,22 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: if socket_type == zmq.PUSH: socket.setsockopt(zmq.SNDHWM, 0) socket.setsockopt(zmq.SNDBUF, buf_size) - socket.connect(f"ipc://{endpoint}") elif socket_type == zmq.PULL: socket.setsockopt(zmq.RCVHWM, 0) socket.setsockopt(zmq.RCVBUF, buf_size) - socket.bind(f"ipc://{endpoint}") else: raise ValueError(f"Unsupported socket type: {socket_type}") + if bind: + socket.bind(endpoint) + else: + socket.connect(endpoint) + return socket def dump_to_file(dirpath, name, value): - from vllm.distributed import get_tensor_model_parallel_rank + from sglang.srt.distributed import get_tensor_model_parallel_rank if get_tensor_model_parallel_rank() != 0: return @@ -1068,9 +1039,6 @@ def get_device_name(device_id: int = 0) -> str: if hasattr(torch, "cuda") and torch.cuda.is_available(): return torch.cuda.get_device_name(device_id) - if hasattr(torch, "hip") and torch.hip.is_available(): - return torch.hip.get_device_name(device_id) - if hasattr(torch, "xpu") and torch.xpu.is_available(): return torch.xpu.get_device_name(device_id) @@ -1083,9 +1051,6 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]: if hasattr(torch, "cuda") and torch.cuda.is_available(): major, minor = torch.cuda.get_device_capability(device_id) - if hasattr(torch, "hip") and torch.hip.is_available(): - major, minor = torch.cuda.get_device_capability(device_id) - if hasattr(torch, "xpu") and torch.xpu.is_available(): major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split( "." @@ -1208,7 +1173,6 @@ def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> # https://github.com/pytorch/pytorch/blob/ # c1cd946818442aca8c7f812b16d187ce1586c3bc/ # torch/cuda/__init__.py#L831C1-L831C17 - import torch.cuda import torch.version if not torch.cuda._is_compiled(): @@ -1241,49 +1205,235 @@ def cuda_device_count_stateless() -> int: return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None)) -def should_use_tensor_core( - kv_cache_dtype: torch.dtype, - num_attention_heads: int, - num_kv_heads: int, -) -> bool: - """ - Determine whether to use tensor cores for attention computation. +def dataclass_to_string_truncated(data, max_length=2048): + if isinstance(data, str): + if len(data) > max_length: + half_length = max_length // 2 + return f"{repr(data[:half_length])} ... {repr(data[-half_length:])}" + else: + return f"{repr(data)}" + elif isinstance(data, (list, tuple)): + if len(data) > max_length: + half_length = max_length // 2 + return str(data[:half_length]) + " ... 
" + str(data[-half_length:]) + else: + return str(data) + elif isinstance(data, dict): + return ( + "{" + + ", ".join( + f"'{k}': {dataclass_to_string_truncated(v, max_length)}" + for k, v in data.items() + ) + + "}" + ) + elif dataclasses.is_dataclass(data): + fields = dataclasses.fields(data) + return ( + f"{data.__class__.__name__}(" + + ", ".join( + f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}" + for f in fields + ) + + ")" + ) + else: + return str(data) - Args: - kv_cache_dtype: Data type of the KV cache - num_attention_heads: Number of attention heads - num_kv_heads: Number of key/value heads - Returns: - bool: Whether to use tensor cores - """ - # Try to use environment variable first - env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE") - if env_override is not None: - return env_override.lower() == "true" +def permute_weight(x: torch.Tensor) -> torch.Tensor: + b_ = x.shape[0] + n_ = x.shape[1] + k_ = x.shape[2] + + x_ = x + if x.dtype == torch.bfloat16 or x.dtype == torch.float16: + x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8) + elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8: + x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16) + else: + return x_ + + x_ = x_.permute(0, 1, 3, 4, 2, 5) + x_ = x_.contiguous() + x_ = x_.view(*x.shape) + return x_ + + +class MultiprocessingSerializer: + @staticmethod + def serialize(obj): + buf = io.BytesIO() + ForkingPickler(buf).dump(obj) + buf.seek(0) + return buf.read() + + @staticmethod + def deserialize(data): + return ForkingPickler.loads(data) + + +def debug_timing(func): + # todo: replace with a more organized instrumentation + def wrapper(*args, **kwargs): + if logger.isEnabledFor(logging.DEBUG): + tic = torch.cuda.Event(enable_timing=True) + toc = torch.cuda.Event(enable_timing=True) + tic.record() + result = func(*args, **kwargs) + toc.record() + torch.cuda.synchronize() # Ensure all CUDA operations are complete + elapsed = tic.elapsed_time(toc) + indices = kwargs.get("indices", args[1] if len(args) > 1 else None) + num_tokens = len(indices) if indices is not None else 0 + throughput = num_tokens / elapsed * 1000 if elapsed > 0 else 0 + logger.debug( + f"Transfer time: {elapsed} ms, throughput: {throughput} tokens/s" + ) + return result + else: + return func(*args, **kwargs) - # Try to use _grouped_size_compiled_for_decode_kernels if available - # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug + return wrapper + + +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + +def pyspy_dump_schedulers(): + """py-spy dump on all scheduler in a local node.""" try: - from flashinfer.decode import _grouped_size_compiled_for_decode_kernels + pid = psutil.Process().pid + # Command to run py-spy with the PID + cmd = f"py-spy dump --pid {pid}" + result = subprocess.run( + cmd, shell=True, capture_output=True, text=True, check=True + ) + logger.info(f"Profile for PID {pid}:\n{result.stdout}") + except subprocess.CalledProcessError as e: + logger.info(f"Failed to profile PID {pid}. 
Error: {e.stderr}")
-
+
+def kill_itself_when_parent_died():
+    if sys.platform == "linux":
+        # sigkill this process when parent worker manager dies
+        PR_SET_PDEATHSIG = 1
+        libc = ctypes.CDLL("libc.so.6")
+        libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
+    else:
+        logger.warning("kill_itself_when_parent_died is only supported in linux.")
+
+
+def set_uvicorn_logging_configs():
+    from uvicorn.config import LOGGING_CONFIG
+
+    LOGGING_CONFIG["formatters"]["default"][
+        "fmt"
+    ] = "[%(asctime)s] %(levelprefix)s %(message)s"
+    LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+    LOGGING_CONFIG["formatters"]["access"][
+        "fmt"
+    ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
+    LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+
+
+def get_ip() -> str:
+    # SGLANG_HOST_IP env can be ignored
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+
+    # IP is not set, try to get it from the network interface
+
+    # try ipv4
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    try:
+        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
         pass

-    # Calculate GQA group size
-    gqa_group_size = num_attention_heads // num_kv_heads
+    # try ipv6
+    try:
+        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+        # Google's public DNS server, see
+        # https://developers.google.com/speed/public-dns/docs/using#addresses
+        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        pass
+
+    warnings.warn(
+        "Failed to get the IP address, using 0.0.0.0 by default. "
+ "The value can be set by the environment variable" + " SGLANG_HOST_IP or HOST_IP.", + stacklevel=2, + ) + return "0.0.0.0" + + +def get_open_port() -> int: + + port = os.getenv("SGLANG_PORT") + if port is not None: + while True: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + return port + except OSError: + port += 1 # Increment port number if already in use + logger.info("Port %d is already in use, trying port %d", port - 1, port) + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] - # Determine based on dtype and GQA group size - if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) return True - elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16): - return gqa_group_size > 4 - else: + except ValueError: return False + + +def rank0_print(msg: str): + from sglang.srt.distributed import get_tensor_model_parallel_rank + + if get_tensor_model_parallel_rank() == 0: + print(msg, flush=True) + + +def launch_dummy_health_check_server(host, port): + import uvicorn + from fastapi import FastAPI, Response + + app = FastAPI() + + @app.get("/health") + async def health(): + """Check the health of the http server.""" + return Response(status_code=200) + + @app.get("/health_generate") + async def health_generate(): + """Check the health of the http server.""" + return Response(status_code=200) + + uvicorn.run( + app, + host=host, + port=port, + timeout_keep_alive=5, + loop="uvloop", + ) diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index f22f9cafaf3..bae0fcf2a49 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -12,7 +12,6 @@ # limitations under the License. 
# ============================================================================== -import json import multiprocessing as mp import os from dataclasses import dataclass @@ -22,8 +21,8 @@ import torch.nn.functional as F from transformers import AutoModelForCausalLM +from sglang.srt.entrypoints.engine import Engine from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.server import Runtime from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER DEFAULT_PROMPTS = [ @@ -278,7 +277,7 @@ def __init__( ): self.model_type = model_type self.is_generation = model_type == "generation" - self.runtime = Runtime( + self.engine = Engine( model_path=model_path, tp_size=tp_size, dtype=get_dtype_str(torch_dtype), @@ -306,7 +305,7 @@ def forward( top_output_logprobs = [] sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} for i, prompt in enumerate(prompts): - response = self.runtime.generate( + response = self.engine.generate( prompt, lora_path=lora_paths[i] if lora_paths else None, sampling_params=sampling_params, @@ -314,7 +313,6 @@ def forward( logprob_start_len=0, top_logprobs_num=NUM_TOP_LOGPROBS, ) - response = json.loads(response) output_strs.append(response["text"]) top_input_logprobs.append( [ @@ -343,8 +341,7 @@ def forward( top_output_logprobs=top_output_logprobs, ) else: - response = self.runtime.encode(prompts) - response = json.loads(response) + response = self.engine.encode(prompts) if self.model_type == "embedding": logits = [x["embedding"] for x in response] return ModelOutput(embed_logits=logits) @@ -366,20 +363,18 @@ def batch_forward( # the return value contains logprobs from prefill output_strs = [] sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} - response = self.runtime.generate( + response = self.engine.generate( prompts, lora_path=lora_paths if lora_paths else None, sampling_params=sampling_params, ) - response = json.loads(response) output_strs = [r["text"] for r in response] return ModelOutput( output_strs=output_strs, ) else: - response = self.runtime.encode(prompts) - response = json.loads(response) + response = self.engine.encode(prompts) if self.model_type == "embedding": logits = [x["embedding"] for x in response] return ModelOutput(embed_logits=logits) @@ -391,8 +386,8 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): - self.runtime.shutdown() - del self.runtime + self.engine.shutdown() + del self.engine def monkey_patch_gemma2_sdpa(): diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py new file mode 100644 index 00000000000..3a02531e695 --- /dev/null +++ b/python/sglang/test/test_block_fp8.py @@ -0,0 +1,341 @@ +import itertools +import unittest + +import torch + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.quantization.fp8_kernel import ( + per_token_group_quant_fp8, + w8a8_block_fp8_matmul, +) + + +# For test +def native_per_token_group_quant_fp8( + x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn +): + """Function to perform per-token-group quantization on an input tensor `x` using native torch. + + It converts the tensor values into float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Note that only `torch.float8_e4m3fn` is supported for now. 
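+    For each group of `group_size` elements, the scale is max(|x|) / fp8_max and the
+    quantized values are (x / scale) clamped to the representable range of `dtype`.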
+    """
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` must be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    finfo = torch.finfo(dtype)
+    fp8_min = finfo.min
+    fp8_max = finfo.max
+
+    x_ = x.reshape(x.numel() // group_size, group_size)
+    amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
+    x_s = amax / fp8_max
+    x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
+    x_q = x_q.reshape(x.shape)
+    x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
+
+    return x_q, x_s
+
+
+class TestPerTokenGroupQuantFP8(unittest.TestCase):
+    DTYPES = [torch.half, torch.bfloat16, torch.float32]
+    NUM_TOKENS = [7, 83, 2048]
+    D = [512, 4096, 5120, 13824]
+    GROUP_SIZE = [64, 128, 256, 512]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _per_token_group_quant_fp8(self, num_tokens, d, dtype, group_size, seed):
+        torch.manual_seed(seed)
+
+        x = torch.rand(num_tokens, d, dtype=dtype)
+
+        with torch.inference_mode():
+            ref_out, ref_scale = native_per_token_group_quant_fp8(x, group_size)
+            out, scale = per_token_group_quant_fp8(x, group_size)
+
+        self.assertTrue(
+            torch.allclose(out.to(torch.float32), ref_out.to(torch.float32), rtol=0.15)
+        )
+        self.assertTrue(torch.allclose(scale, ref_scale))
+
+    def test_per_token_group_quant_fp8(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.D,
+            self.DTYPES,
+            self.GROUP_SIZE,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                d=params[1],
+                dtype=params[2],
+                group_size=params[3],
+                seed=params[4],
+            ):
+                self._per_token_group_quant_fp8(*params)
+
+
+# For test
+def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
+    """This function performs matrix multiplication with block-wise quantization using native torch.
+
+    It takes two input tensors `A` and `B` with scales `As` and `Bs`.
+    The output is returned in the specified `output_dtype`.
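+    Each (block_n, block_k) tile of `B` carries one scale in `Bs`, and every row of `A`
+    carries one scale per k-tile in `As`; each partial tile product is multiplied by
+    As * Bs before being accumulated into the output.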
+ """ + + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N,) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)] + B_tiles = [ + [ + B[ + j * block_n : min((j + 1) * block_n, N), + i * block_k : min((i + 1) * block_k, K), + ] + for i in range(k_tiles) + ] + for j in range(n_tiles) + ] + C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)] + As_tiles = [As[:, i : i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + +class TestW8A8BlockFP8Matmul(unittest.TestCase): + OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16] + M = [1, 7, 83, 512, 2048] + N = [128, 512, 1024, 4096, 7748, 13824] + K = [256, 4096, 5120, 3884, 13824] + # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] + BLOCK_SIZE = [[128, 128]] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed): + torch.manual_seed(seed) + # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + A_fp32 = (torch.rand(M, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + A_fp8 = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + B_fp32 = (torch.rand(N, K, dtype=torch.float32) - 0.5) * 2 * fp8_max + B_fp8 = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + + As = torch.rand(M, k_tiles, dtype=torch.float32) * factor_for_scale + Bs = torch.rand(n_tiles, k_tiles, dtype=torch.float32) * factor_for_scale + + with torch.inference_mode(): + ref_out = native_w8a8_block_fp8_matmul( + A_fp8, B_fp8, As, Bs, block_size, out_dtype + ) + out = w8a8_block_fp8_matmul(A_fp8, B_fp8, As, Bs, block_size, out_dtype) + + self.assertTrue( + torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) + / torch.mean(torch.abs(ref_out.to(torch.float32))) + < 0.001 + ) + + def test_w8a8_block_fp8_matmul(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.BLOCK_SIZE, + self.OUT_DTYPES, + self.SEEDS, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + block_size=params[3], + out_dtype=params[4], + seed=params[5], + ): + self._w8a8_block_fp8_matmul(*params) + + +# For test +def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """This function performs fused moe 
with block-wise quantization using native torch.""" + + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_fp8(a, block_k) + # NOTE(HandH1998): Since "index_cuda" not implemented for 'Float8_e4m3fn', we need to cast `float8`` to `float32``. + a_q = a_q.to(torch.float32) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = native_w8a8_block_fp8_matmul( + a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype + ) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = native_per_token_group_quant_fp8(act_out, block_k) + act_out = act_out.to(torch.float32) + out[mask] = native_w8a8_block_fp8_matmul( + act_out_q, w2[i], act_out_s, w2_s[i], block_shape, output_dtype=a.dtype + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + +class TestW8A8BlockFP8FusedMoE(unittest.TestCase): + DTYPES = [torch.float32, torch.half, torch.bfloat16] + M = [1, 33, 64, 222, 1024 * 128] + N = [128, 1024, 2048] + K = [256, 4096, 5120] + E = [8, 24] + TOP_KS = [2, 6] + BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] + # BLOCK_SIZE = [[128, 128]] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _w8a8_block_fp8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed): + torch.manual_seed(seed) + # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half + factor_for_scale = 1e-2 + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * fp8_max + w1 = w1_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * fp8_max + w2 = w2_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = ( + torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + * factor_for_scale + ) + w2_s = ( + torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + * factor_for_scale + ) + + score = torch.randn((M, E), dtype=dtype) + + with torch.inference_mode(): + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_fp8_moe( + a, w1, w2, w1_s, w2_s, score, topk, block_size + ) + + self.assertTrue( + torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) + / torch.mean(torch.abs(ref_out.to(torch.float32))) + < 0.02 + ) + + def test_w8a8_block_fp8_fused_moe(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.E, + self.TOP_KS, + self.BLOCK_SIZE, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + E=params[3], + 
topk=params[4], + block_size=params[5], + dtype=params[6], + seed=params[7], + ): + self._w8a8_block_fp8_fused_moe(*params) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index a251e0acaaa..088cb0d0af9 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -509,13 +509,36 @@ def few_shot_hellaswag(s, question, choices): temperature=0, num_threads=64, progress_bar=True, + generator_style=False, ) - preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))] + preds = [] + for i, ret in enumerate(rets): + preds.append(choices[i].index(ret["answer"])) latency = time.time() - tic # Compute accuracy accuracy = np.mean(np.array(preds) == np.array(labels)) + # Test generator style of run_batch + tic = time.time() + rets = few_shot_hellaswag.run_batch( + arguments, + temperature=0, + num_threads=64, + progress_bar=True, + generator_style=True, + ) + preds_gen = [] + for i, ret in enumerate(rets): + preds_gen.append(choices[i].index(ret["answer"])) + latency_gen = time.time() - tic + + # Compute accuracy + accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels)) + print(f"{accuracy=}, {accuracy_gen=}") + assert np.abs(accuracy_gen - accuracy) < 0.05 + assert np.abs(latency_gen - latency) < 1 + return accuracy, latency diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index f97fc12355a..b303f19121d 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -34,12 +34,16 @@ DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct" DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8" -DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600 +DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 1000 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it" -DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" +DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct" DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8" DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8" DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4" +DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct" + +DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf" +DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmzheng/sglang-EAGLE-llama2-chat-7B" def is_in_ci(): @@ -131,10 +135,6 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None): return pred -def call_generate_gserver(prompt, temperature, max_tokens, stop=None, url=None): - raise NotImplementedError() - - def 
call_generate_guidance( prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None ): @@ -405,7 +405,7 @@ def popen_launch_server( base_url: str, timeout: float, api_key: Optional[str] = None, - other_args: tuple = (), + other_args: list[str] = (), env: Optional[dict] = None, return_stdout_stderr: Optional[tuple] = None, ): @@ -526,15 +526,60 @@ def get_similarities(vec1, vec2): return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0) +def get_benchmark_args( + base_url="", + dataset_name="", + dataset_path="", + tokenizer="", + num_prompts=500, + random_input_len=4096, + random_output_len=2048, + request_rate=float("inf"), + disable_stream=False, + disable_ignore_eos=False, +): + return SimpleNamespace( + backend="sglang", + base_url=base_url, + host=None, + port=None, + dataset_name=dataset_name, + dataset_path=dataset_path, + model=None, + tokenizer=tokenizer, + num_prompts=num_prompts, + sharegpt_output_len=None, + sharegpt_context_len=None, + random_input_len=random_input_len, + random_output_len=random_output_len, + random_range_ratio=0.0, + request_rate=request_rate, + multi=None, + output_file=None, + disable_tqdm=False, + disable_stream=disable_stream, + return_logprob=False, + seed=0, + disable_ignore_eos=disable_ignore_eos, + extra_request_body=None, + apply_chat_template=False, + profile=None, + lora_name=None, + ) + + def run_bench_serving( model, num_prompts, request_rate, other_server_args, dataset_name="random", + dataset_path="", + tokenizer=None, random_input_len=4096, random_output_len=2048, disable_stream=False, + disable_ignore_eos=False, need_warmup=False, ): # Launch the server @@ -547,30 +592,17 @@ def run_bench_serving( ) # Run benchmark - args = SimpleNamespace( - backend="sglang", + args = get_benchmark_args( base_url=base_url, - host=None, - port=None, dataset_name=dataset_name, - dataset_path="", - model=None, - tokenizer=None, + dataset_path=dataset_path, + tokenizer=tokenizer, num_prompts=num_prompts, - sharegpt_output_len=None, random_input_len=random_input_len, random_output_len=random_output_len, - random_range_ratio=0.0, request_rate=request_rate, - multi=None, - seed=0, - output_file=None, - disable_tqdm=False, disable_stream=disable_stream, - disable_ignore_eos=False, - lora_name=None, - extra_request_body=None, - profile=None, + disable_ignore_eos=disable_ignore_eos, ) try: @@ -586,6 +618,38 @@ def run_bench_serving( return res +def run_bench_serving_multi( + model, + base_url, + other_server_args, + benchmark_args, + need_warmup=False, +): + # Launch the server + process = popen_launch_server( + model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_server_args, + ) + + # run benchmark for all + res_l = [] + try: + for args in benchmark_args: + if need_warmup: + warmup_args = copy.deepcopy(args) + warmup_args.num_prompts = 16 + run_benchmark(warmup_args) + + res = run_benchmark(args) + res_l.append((args, res)) + finally: + kill_process_tree(process.pid) + + return res_l + + def run_bench_one_batch(model, other_args): command = [ "python3", @@ -656,16 +720,16 @@ def calculate_rouge_l(output_strs_list1, output_strs_list2): STDOUT_FILENAME = "stdout.txt" -def read_output(output_lines): +def read_output(output_lines: List[str], filename: str = STDERR_FILENAME): """Print the output in real time with another thread.""" - while not os.path.exists(STDERR_FILENAME): + while not os.path.exists(filename): time.sleep(1) pt = 0 while pt >= 0: - if pt > 0 and not os.path.exists(STDERR_FILENAME): + if pt > 0 
and not os.path.exists(filename): break - lines = open(STDERR_FILENAME).readlines() + lines = open(filename).readlines() for line in lines[pt:]: print(line, end="", flush=True) output_lines.append(line) @@ -719,13 +783,13 @@ def run_and_check_memory_leak( # Clean up everything kill_process_tree(process.pid) - kill_process_tree(process.pid) stdout.close() stderr.close() if os.path.exists(STDOUT_FILENAME): os.remove(STDOUT_FILENAME) if os.path.exists(STDERR_FILENAME): os.remove(STDERR_FILENAME) + kill_process_tree(process.pid) t.join() # Assert success @@ -733,7 +797,7 @@ def run_and_check_memory_leak( has_leak = False has_abort = False for line in output_lines: - if "The server is fired" in line: + if "Uvicorn running" in line: has_new_server = True if "leak" in line: has_leak = True @@ -746,6 +810,33 @@ def run_and_check_memory_leak( assert has_abort +def run_command_and_capture_output(command, env: Optional[dict] = None): + stdout = open(STDOUT_FILENAME, "w") + stderr = open(STDERR_FILENAME, "w") + process = subprocess.Popen( + command, stdout=stdout, stderr=stderr, env=env, text=True + ) + + # Launch a thread to stream the output + output_lines = [] + t = threading.Thread(target=read_output, args=(output_lines, STDOUT_FILENAME)) + t.start() + + # Join the process + process.wait() + + stdout.close() + stderr.close() + if os.path.exists(STDOUT_FILENAME): + os.remove(STDOUT_FILENAME) + if os.path.exists(STDERR_FILENAME): + os.remove(STDERR_FILENAME) + kill_process_tree(process.pid) + t.join() + + return output_lines + + def run_mmlu_test( disable_radix_cache=False, enable_mixed_chunk=False, diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 6465db7b81a..59982f037c8 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -1,7 +1,6 @@ -"""Common utilities.""" +"""Common utilities""" import base64 -import gc import importlib import json import logging @@ -15,7 +14,7 @@ from concurrent.futures import ThreadPoolExecutor from io import BytesIO from json import dumps -from typing import Optional, Union +from typing import Any, Callable, List, Optional, Tuple, Type, Union import numpy as np import requests @@ -79,7 +78,15 @@ def status_code(self): return self.resp.status -def http_request(url, json=None, stream=False, api_key=None, verify=None, timeout=None): +def http_request( + url, + json=None, + stream=False, + api_key=None, + verify=None, + timeout=None, + method: Optional[str] = None, +): """A faster version of requests.post with low-level urllib API.""" headers = {"Content-Type": "application/json; charset=utf-8"} @@ -92,7 +99,7 @@ def http_request(url, json=None, stream=False, api_key=None, verify=None, timeou url, json=json, stream=True, headers=headers, timeout=timeout ) else: - req = urllib.request.Request(url, headers=headers) + req = urllib.request.Request(url, headers=headers, method=method) if json is None: data = None else: @@ -360,3 +367,56 @@ def terminate_process(process): def print_highlight(html_content: str): html_content = str(html_content).replace("\n", "
") display(HTML(f"{html_content}")) + + +class TypeBasedDispatcher: + def __init__(self, mapping: List[Tuple[Type, Callable]]): + self._mapping = mapping + + def __call__(self, obj: Any): + for ty, fn in self._mapping: + if isinstance(obj, ty): + return fn(obj) + raise ValueError(f"Invalid object: {obj}") + + +def trim_overlap(existing_text, new_chunk): + """ + Finds the largest suffix of 'existing_text' that is a prefix of 'new_chunk' + and removes that overlap from the start of 'new_chunk'. + """ + max_overlap = 0 + max_possible = min(len(existing_text), len(new_chunk)) + for i in range(max_possible, 0, -1): + if existing_text.endswith(new_chunk[:i]): + max_overlap = i + break + return new_chunk[max_overlap:] + + +def stream_and_merge(llm, prompt, sampling_params): + """ + 1) Streams the text, + 2) Removes chunk overlaps, + 3) Returns the merged text. + """ + final_text = "" + for chunk in llm.generate(prompt, sampling_params, stream=True): + chunk_text = chunk["text"] + cleaned_chunk = trim_overlap(final_text, chunk_text) + final_text += cleaned_chunk + return final_text + + +async def async_stream_and_merge(llm, prompt, sampling_params): + """ + Streams tokens asynchronously, removes chunk overlaps, + and yields the cleaned chunk in real time for printing. + """ + final_text = "" + generator = await llm.async_generate(prompt, sampling_params, stream=True) + async for chunk in generator: + chunk_text = chunk["text"] + cleaned_chunk = trim_overlap(final_text, chunk_text) + final_text += cleaned_chunk + yield cleaned_chunk # yield the non-overlapping portion diff --git a/python/sglang/version.py b/python/sglang/version.py index a21caf9d324..d1b3e6d0ae9 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.4.0.post1" +__version__ = "0.4.2.post1" diff --git a/rust/README.md b/rust/README.md deleted file mode 100644 index 84a8e8fb1d0..00000000000 --- a/rust/README.md +++ /dev/null @@ -1,183 +0,0 @@ -# SGLang Router - -SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances. - -## Installation - -```bash -pip install sglang-router -``` - -## Usage -The router offers two modes: - -### 1. Co-launch workers and router -This will be a drop-in replacement for the existing `--dp-size`. This part of code will be moved into sglang core. -Under the hood, it uses multi-processes to launch multiple sglang workers, wait for them to be healthy, then launch the router. - -```bash -$ python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --dp-size 8 -``` - -### 2. Launch only router -This is useful for multi-node DP. You can launch workers on different nodes, then connect the router to them. - -```bash -$ python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 - -$ python -m sglang_router.launch_router --help -usage: launch_router.py [-h] [--host HOST] [--port PORT] [--worker-urls WORKER_URLS [WORKER_URLS ...]] - [--policy {random,round_robin,cache_aware}] [--cache-threshold CACHE_THRESHOLD] - [--balance-abs-threshold BALANCE_ABS_THRESHOLD] [--balance-rel-threshold BALANCE_REL_THRESHOLD] - [--eviction-interval EVICTION_INTERVAL] [--max-tree-size MAX_TREE_SIZE] - -options: - -h, --help show this help message and exit - --host HOST Host address to bind the router server (default: 127.0.0.1) - --port PORT Port number to bind the router server (default: 30000) - --worker-urls WORKER_URLS [WORKER_URLS ...] 
- List of worker URLs (e.g., http://worker1:8000 http://worker2:8000) (default: None) - --policy {random,round_robin,cache_aware} - Load balancing policy to use (default: cache_aware) - --cache-threshold CACHE_THRESHOLD - Cache threshold (0.0-1.0) for cache-aware routing (default: 0.5) - --balance-abs-threshold BALANCE_ABS_THRESHOLD - Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold (default: 32) - --balance-rel-threshold BALANCE_REL_THRESHOLD - Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold (default: 1.0001) - --eviction-interval EVICTION_INTERVAL - Interval in seconds between cache eviction operations (default: 60) - --max-tree-size MAX_TREE_SIZE - Maximum size of the approximation tree for cache-aware routing (default: 16777216) -``` - -## Strategy - -### Cache-Aware Load-Balancing Router - -This router combines two strategies to optimize both cache utilization and request distribution: - -1. Cache-Aware Routing (Approximate Tree) -2. Load-Balancing Routing (Shortest Queue with Balance Thresholds) - -The router dynamically switches between these strategies based on load conditions: -- Uses load balancing when the system is imbalanced -- Uses cache-aware routing when the system is balanced - -A system is considered imbalanced if both conditions are met: -1. (max_load - min_load) > balance_abs_threshold -2. max_load > balance_rel_threshold * min_load - -#### 1. Cache-Aware Routing (Approximate Tree) -This strategy maintains an approximate radix tree for each worker based on request history, -eliminating the need for direct cache state queries. The tree stores raw text characters -instead of token IDs to avoid tokenization overhead. - -Process: -- For each request, find the worker with the highest prefix match -- If match rate > cache_threshold: - - Route to the worker with highest match (likely has relevant data cached) -- If match rate ≤ cache_threshold: - - Route to the worker with smallest tree size (most available cache capacity) -- Background maintenance: - - Periodically evict least recently used leaf nodes to prevent memory overflow - -#### 2. Load-Balancing (Shortest Queue) -This strategy tracks pending request counts per worker and routes new requests -to the least busy worker when the system is detected to be imbalanced. This helps -maintain optimal load distribution across workers. - -### Configuration Parameters - -1. `cache_threshold`: (float, 0.0 to 1.0, default: 0.5) - - Minimum prefix match ratio to use highest-match routing - - Below this threshold, routes to worker with most available cache space - -2. `balance_abs_threshold`: (integer, default: 32) - - Absolute difference threshold for load imbalance detection - - System is potentially imbalanced if (max_load - min_load) > abs_threshold - -3. `balance_rel_threshold`: (float, default: 1.0001) - - Relative ratio threshold for load imbalance detection - - System is potentially imbalanced if max_load > min_load * rel_threshold - - Used in conjunction with abs_threshold to determine final imbalance state - -4. `eviction_interval`: (integer, default: 60) - - Interval in seconds between LRU eviction cycles for the approximate trees - - Background thread periodically evicts least recently used nodes to maintain tree size - -5. 
`max_tree_size`: (integer, default: 16777216) - - Maximum nodes per tree - - When exceeded, LRU leaf nodes are evicted during the next eviction cycle - -## Development - -- Rust and Cargo installed - -```bash -# Install rustup (Rust installer and version manager) -curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh - -# Follow the installation prompts, then reload your shell -source $HOME/.cargo/env - -# Verify installation -rustc --version -cargo --version -``` - -- Python with pip installed - - -### Build Process - -#### 1. Build Rust Project - -```bash -cargo build -``` - -#### 2. Build Python Binding - -##### Option A: Build and Install Wheel -1. Build the wheel package: -```bash -pip install setuptools-rust wheel build -python -m build -``` - -2. Install the generated wheel: -```bash -pip install -``` - -##### Option B: Development Mode - -For development purposes, you can install the package in editable mode: - -Warning: Using editable python binding can suffer from performance degradation!! Please build a fresh wheel for every update if you want to test performance. - -```bash -pip install -e . -``` - -**Note:** When modifying Rust code, you must rebuild the wheel for changes to take effect. - -### CI/CD Setup - -The continuous integration pipeline consists of three main steps: - -#### 1. Build Wheels -- Uses `cibuildwheel` to create manylinux x86_64 packages -- Compatible with major Linux distributions (Ubuntu, CentOS, etc.) -- Additional configurations can be added to support other OS/architectures -- Reference: [cibuildwheel documentation](https://cibuildwheel.pypa.io/en/stable/) - -#### 2. Build Source Distribution -- Creates a source distribution containing the raw, unbuilt code -- Enables `pip` to build the package from source when prebuilt wheels are unavailable - -#### 3. Publish to PyPI -- Uploads both wheels and source distribution to PyPI - -The CI configuration is based on the [tiktoken workflow](https://github.com/openai/tiktoken/blob/63527649963def8c759b0f91f2eb69a40934e468/.github/workflows/build_wheels.yml#L1). diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py deleted file mode 100644 index 0dacc2c9f7d..00000000000 --- a/rust/py_test/test_launch_server.py +++ /dev/null @@ -1,184 +0,0 @@ -import socket -import subprocess -import time -import unittest -from types import SimpleNamespace - -import requests - -from sglang.srt.utils import kill_process_tree -from sglang.test.run_eval import run_eval -from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, -) - - -def popen_launch_router( - model: str, - base_url: str, - dp_size: int, - timeout: float, -): - """ - Launch the router server process. 
- - Args: - model: Model path/name - base_url: Server base URL - dp_size: Data parallel size - timeout: Server launch timeout - """ - _, host, port = base_url.split(":") - host = host[2:] - - command = [ - "python3", - "-m", - "sglang_router.launch_server", - "--model-path", - model, - "--host", - host, - "--port", - port, - "--dp", - str(dp_size), # Convert dp_size to string - "--router-eviction-interval", - "5", # frequent eviction for testing - ] - - # Use current environment - env = None - - process = subprocess.Popen(command, stdout=None, stderr=None) - - start_time = time.time() - with requests.Session() as session: - while time.time() - start_time < timeout: - try: - response = session.get(f"{base_url}/health") - if response.status_code == 200: - print(f"Router {base_url} is healthy") - return process - except requests.RequestException: - pass - time.sleep(10) - - raise TimeoutError("Router failed to start within the timeout period.") - - -def find_available_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -def popen_launch_server( - model: str, - base_url: str, - timeout: float, -): - _, host, port = base_url.split(":") - host = host[2:] - - command = [ - "python3", - "-m", - "sglang.launch_server", - "--model-path", - model, - "--host", - host, - "--port", - port, - "--base-gpu-id", - "1", - ] - - process = subprocess.Popen(command, stdout=None, stderr=None) - - start_time = time.time() - with requests.Session() as session: - while time.time() - start_time < timeout: - try: - response = session.get(f"{base_url}/health") - if response.status_code == 200: - print(f"Server {base_url} is healthy") - return process - except requests.RequestException: - pass - time.sleep(10) - - raise TimeoutError("Server failed to start within the timeout period.") - - -class TestEvalAccuracyMini(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_router( - cls.model, - cls.base_url, - dp_size=1, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - ) - cls.other_process = [] - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) - for process in cls.other_process: - kill_process_tree(process.pid) - - def test_mmlu(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - temperature=0.1, - ) - - metrics = run_eval(args) - score = metrics["score"] - THRESHOLD = 0.65 - passed = score >= THRESHOLD - msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" - self.assertGreaterEqual(score, THRESHOLD, msg) - - def test_add_worker(self): - # 1. start a worker, and wait until it is healthy - port = find_available_port() - worker_url = f"http://127.0.0.1:{port}" - worker_process = popen_launch_server( - self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH - ) - self.other_process.append(worker_process) - # 2. use /add_worker api to add it the the router - with requests.Session() as session: - response = session.post(f"{self.base_url}/add_worker?url={worker_url}") - print(f"status code: {response.status_code}, response: {response.text}") - self.assertEqual(response.status_code, 200) - # 3. 
run mmlu - args = SimpleNamespace( - base_url=self.base_url, - model=self.model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - temperature=0.1, - ) - metrics = run_eval(args) - score = metrics["score"] - THRESHOLD = 0.65 - passed = score >= THRESHOLD - msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" - self.assertGreaterEqual(score, THRESHOLD, msg) - - -if __name__ == "__main__": - unittest.main() diff --git a/rust/src/main.rs b/rust/src/main.rs deleted file mode 100644 index e450f2c54c3..00000000000 --- a/rust/src/main.rs +++ /dev/null @@ -1,125 +0,0 @@ -use clap::Parser; -use clap::ValueEnum; - -use sglang_router_rs::{router::PolicyConfig, server, server::ServerConfig}; - -#[derive(Debug, Clone, ValueEnum)] -pub enum PolicyType { - Random, - RoundRobin, - CacheAware, -} - -#[derive(Parser, Debug)] -#[command(author, version, about, long_about = None)] -struct Args { - #[arg( - long, - default_value = "127.0.0.1", - help = "Host address to bind the router server to. Default: 127.0.0.1" - )] - host: String, - - #[arg( - long, - default_value_t = 3001, - help = "Port number to bind the router server to. Default: 3001" - )] - port: u16, - - #[arg( - long, - value_delimiter = ',', - help = "Comma-separated list of worker URLs that will handle the requests. Each URL should include the protocol, host, and port (e.g., http://worker1:8000,http://worker2:8000)" - )] - worker_urls: Vec, - - #[arg( - long, - default_value_t = PolicyType::CacheAware, - value_enum, - help = "Load balancing policy to use for request distribution:\n\ - - random: Randomly select workers\n\ - - round_robin: Distribute requests in round-robin fashion\n\ - - cache_aware: Distribute requests based on cache state and load balance\n" - )] - policy: PolicyType, - - #[arg( - long, - default_value_t = 0.5, - requires = "policy", - required_if_eq("policy", "cache_aware"), - help = "Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. Default: 0.5" - )] - cache_threshold: f32, - - #[arg( - long, - default_value_t = 32, - requires = "policy", - required_if_eq("policy", "cache_aware"), - help = "Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 32" - )] - balance_abs_threshold: usize, - - #[arg( - long, - default_value_t = 1.0001, - requires = "policy", - required_if_eq("policy", "cache_aware"), - help = "Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001" - )] - balance_rel_threshold: f32, - - #[arg( - long, - default_value_t = 60, - requires = "policy", - required_if_eq("policy", "cache_aware"), - help = "Interval in seconds between cache eviction operations in cache-aware routing. Default: 60" - )] - eviction_interval_secs: u64, - - #[arg( - long, - default_value_t = 2usize.pow(24), - requires = "policy", - required_if_eq("policy", "cache_aware"), - help = "Maximum size of the approximation tree for cache-aware routing. 
Default: 2^24" - )] - max_tree_size: usize, - - #[arg(long, default_value_t = false, help = "Enable verbose logging")] - verbose: bool, -} - -impl Args { - fn get_policy_config(&self) -> PolicyConfig { - match self.policy { - PolicyType::Random => PolicyConfig::RandomConfig, - PolicyType::RoundRobin => PolicyConfig::RoundRobinConfig, - PolicyType::CacheAware => PolicyConfig::CacheAwareConfig { - cache_threshold: self.cache_threshold, - balance_abs_threshold: self.balance_abs_threshold, - balance_rel_threshold: self.balance_rel_threshold, - eviction_interval_secs: self.eviction_interval_secs, - max_tree_size: self.max_tree_size, - }, - } - } -} - -#[actix_web::main] -async fn main() -> std::io::Result<()> { - let args = Args::parse(); - let policy_config = args.get_policy_config(); - server::startup(ServerConfig { - host: args.host, - port: args.port, - worker_urls: args.worker_urls, - policy_config, - verbose: args.verbose, - }) - .await -} diff --git a/rust/src/router.rs b/rust/src/router.rs deleted file mode 100644 index 2b6b8d52cff..00000000000 --- a/rust/src/router.rs +++ /dev/null @@ -1,399 +0,0 @@ -use crate::tree::Tree; -use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; -use actix_web::{HttpRequest, HttpResponse}; -use bytes::Bytes; -use futures_util::{StreamExt, TryStreamExt}; -use log::{debug, info}; -use std::collections::HashMap; -use std::fmt::Debug; -use std::sync::atomic::AtomicUsize; -use std::sync::{Arc, Mutex, RwLock}; -use std::thread; -use std::time::Duration; - -#[derive(Debug)] -pub enum Router { - RoundRobin { - worker_urls: Arc>>, - current_index: AtomicUsize, - }, - Random { - worker_urls: Arc>>, - }, - CacheAware { - /* - Cache-Aware Load Balancing Router - - This router combines two strategies to optimize both cache utilization and request distribution: - - 1. Cache-Aware Routing (Approximate Tree) - 2. Load Balancing (Shortest Queue with Balance Thresholds) - - The router dynamically switches between these strategies based on load conditions: - - Uses load balancing when the system is imbalanced - - Uses cache-aware routing when the system is balanced - - A system is considered imbalanced if both conditions are met: - 1. (max - min) > abs_threshold - 2. max > rel_threshold * min - - Strategy Details: - - 1. Cache-Aware Routing (Approximate Tree) - ------------------------------------------- - This strategy maintains an approximate radix tree for each worker based on request history, - eliminating the need for direct cache state queries. The tree stores raw text characters - instead of token IDs to avoid tokenization overhead. - - Process: - a. For each request, find the worker with the highest prefix match - b. If match rate > cache_threshold: - Route to the worker with highest match (likely has relevant data cached) - c. If match rate ≤ cache_threshold: - Route to the worker with smallest tree size (most available cache capacity) - d. Background maintenance: - Periodically evict least recently used leaf nodes to prevent memory overflow - - 2. Load Balancing (Shortest Queue) - ------------------------------------------- - This strategy tracks pending request counts per worker and routes new requests - to the least busy worker when the system is detected to be imbalanced. - - Configuration Parameters: - ------------------------ - 1. cache_threshold: (float, 0.0 to 1.0) - Minimum prefix match ratio to use highest-match routing. - Below this threshold, routes to worker with most available cache space. - - 2. 
balance_abs_threshold: (integer) - Absolute difference threshold for load imbalance detection. - System is potentially imbalanced if (max_load - min_load) > abs_threshold - - 3. balance_rel_threshold: (float) - Relative ratio threshold for load imbalance detection. - System is potentially imbalanced if max_load > min_load * rel_threshold - Used in conjunction with abs_threshold to determine final imbalance state. - - 4. eviction_interval_secs: (integer) - Interval between LRU eviction cycles for the approximate trees. - - 5. max_tree_size: (integer) - Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted - during the next eviction cycle. - */ - worker_urls: Arc>>, - tree: Arc>, - running_queue: Arc>>, - processed_queue: Arc>>, - cache_threshold: f32, - balance_abs_threshold: usize, - balance_rel_threshold: f32, - _eviction_thread: Option>, - }, -} - -#[derive(Debug)] -pub enum PolicyConfig { - RandomConfig, - RoundRobinConfig, - CacheAwareConfig { - cache_threshold: f32, - balance_abs_threshold: usize, - balance_rel_threshold: f32, - eviction_interval_secs: u64, - max_tree_size: usize, - }, -} - -fn get_text_from_request(body: &Bytes, route: &str) -> String { - // convert body to json - let json = serde_json::from_slice::(body).unwrap(); - - if route == "generate" { - // get the "text" field - let text = json.get("text").and_then(|t| t.as_str()).unwrap_or(""); - return text.to_string(); - } else if route == "v1/chat/completions" { - // get the messages field as raw text - if let Some(messages) = json.get("messages") { - // Convert messages back to a string, preserving all JSON formatting - return serde_json::to_string(messages).unwrap_or_default(); - } - } else if route == "v1/completions" { - let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or(""); - return prompt.to_string(); - } - - return "".to_string(); -} -impl Router { - pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Self { - match policy_config { - PolicyConfig::RandomConfig => Router::Random { - worker_urls: Arc::new(RwLock::new(worker_urls)), - }, - PolicyConfig::RoundRobinConfig => Router::RoundRobin { - worker_urls: Arc::new(RwLock::new(worker_urls)), - current_index: std::sync::atomic::AtomicUsize::new(0), - }, - PolicyConfig::CacheAwareConfig { - cache_threshold, - balance_abs_threshold, - balance_rel_threshold, - eviction_interval_secs, - max_tree_size, - } => { - let mut running_queue = HashMap::new(); - for url in &worker_urls { - running_queue.insert(url.clone(), 0); - } - - let mut processed_queue = HashMap::new(); - for url in &worker_urls { - processed_queue.insert(url.clone(), 0); - } - - let tree = Arc::new(Mutex::new(Tree::new())); - let running_queue = Arc::new(Mutex::new(running_queue)); - let processed_queue = Arc::new(Mutex::new(processed_queue)); - - // Create background eviction thread - let tree_clone = Arc::clone(&tree); - let processed_queue_clone = Arc::clone(&processed_queue); - let running_queue_clone = Arc::clone(&running_queue); - let eviction_thread = thread::spawn(move || { - loop { - // Sleep for the specified interval - thread::sleep(Duration::from_secs(eviction_interval_secs)); - - let locked_tree_clone = tree_clone.lock().unwrap(); - // Run eviction - locked_tree_clone.evict_tenant_by_size(max_tree_size); - - // Print the process queue - let locked_processed_queue = processed_queue_clone.lock().unwrap(); - info!("Processed Queue: {:?}", locked_processed_queue); - - // Print the running queue - let locked_running_queue = running_queue_clone.lock().unwrap(); 
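
The strategy comment above describes the selection rule in prose: shortest-queue routing when the load spread exceeds both the absolute and relative thresholds, otherwise prefix-match routing with a smallest-tree fallback. As a reading aid, here is a minimal Python sketch of that decision under the documented defaults; it is not the router's Rust implementation, and the precomputed `match_rates` and `tree_sizes` dicts are stand-ins for the approximate radix tree.

```python
# Minimal sketch of the cache-aware selection rule described above.
# `running_queue` maps worker URL -> pending request count;
# `match_rates` and `tree_sizes` stand in for the approximate radix tree.
def pick_worker(running_queue, match_rates, tree_sizes,
                cache_threshold=0.5,
                balance_abs_threshold=32,
                balance_rel_threshold=1.0001):
    max_load = max(running_queue.values())
    min_load = min(running_queue.values())
    imbalanced = (max_load - min_load) > balance_abs_threshold and \
        max_load > min_load * balance_rel_threshold
    if imbalanced:
        # Shortest-queue routing while the system is imbalanced.
        return min(running_queue, key=running_queue.get)
    best = max(match_rates, key=match_rates.get)
    if match_rates[best] > cache_threshold:
        # High prefix match: this worker likely has the prefix cached.
        return best
    # Low match rate: route to the worker with the most spare cache capacity.
    return min(tree_sizes, key=tree_sizes.get)


workers = ["http://worker1:8000", "http://worker2:8000"]
print(pick_worker(
    running_queue={workers[0]: 40, workers[1]: 3},
    match_rates={workers[0]: 0.8, workers[1]: 0.1},
    tree_sizes={workers[0]: 10_000, workers[1]: 2_000},
))  # prints http://worker2:8000 (load spread 37 > 32, so the balancing branch wins)
```

With a balanced queue (say 5 vs. 3 pending requests) the same call would instead return worker1, because its 0.8 prefix-match rate clears the 0.5 cache threshold.
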
- info!("Running Queue: {:?}", locked_running_queue); - } - }); - - for url in &worker_urls { - tree.lock().unwrap().insert(&"".to_string(), url); - } - - Router::CacheAware { - worker_urls: Arc::new(RwLock::new(worker_urls)), - tree, - running_queue, - processed_queue, - cache_threshold, - balance_abs_threshold, - balance_rel_threshold, - _eviction_thread: Some(eviction_thread), - } - } - } - } - - pub fn get_first(&self) -> Option { - match self { - Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls } - | Router::CacheAware { worker_urls, .. } => { - if worker_urls.read().unwrap().is_empty() { - None - } else { - Some(worker_urls.read().unwrap()[0].clone()) - } - } - } - } - - pub async fn dispatch( - &self, - client: &reqwest::Client, - req: HttpRequest, - body: Bytes, - route: &str, - ) -> HttpResponse { - let text = get_text_from_request(&body, route); - - let worker_url = match self { - Router::RoundRobin { - worker_urls, - current_index, - } => { - let idx = current_index - .fetch_update( - std::sync::atomic::Ordering::SeqCst, - std::sync::atomic::Ordering::SeqCst, - |x| Some((x + 1) % worker_urls.read().unwrap().len()), - ) - .unwrap(); - worker_urls.read().unwrap()[idx].clone() - } - - Router::Random { worker_urls } => worker_urls.read().unwrap() - [rand::random::() % worker_urls.read().unwrap().len()] - .clone(), - - Router::CacheAware { - worker_urls, - tree, - running_queue, - processed_queue, - cache_threshold, - balance_abs_threshold, - balance_rel_threshold, - .. - } => { - // TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones - - let tree = tree.lock().unwrap(); - let mut running_queue = running_queue.lock().unwrap(); - - // Get current load statistics - let max_load = *running_queue.values().max().unwrap_or(&0); - let min_load = *running_queue.values().min().unwrap_or(&0); - - // Load is considered imbalanced if: - // 1. (max - min) > abs_threshold AND - // 2. 
max > rel_threshold * min - let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold - && (max_load as f32) > (min_load as f32 * balance_rel_threshold); - - let selected_url = if is_imbalanced { - // Log load balancing trigger and current queue state - info!( - "Load balancing triggered due to workload imbalance:\n\ - Max load: {}, Min load: {}\n\ - Current running queue: {:?}", - max_load, min_load, running_queue - ); - - // Use shortest queue routing when load is imbalanced - running_queue - .iter() - .min_by_key(|(_url, &count)| count) - .map(|(url, _)| url.clone()) - .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone()) - } else { - // Use cache-aware routing when load is balanced - let (matched_text, matched_worker) = tree.prefix_match(&text); - let matched_rate = - matched_text.chars().count() as f32 / text.chars().count() as f32; - - if matched_rate > *cache_threshold { - matched_worker.to_string() - } else { - tree.get_smallest_tenant() - } - }; - - // Update queues and tree - *running_queue.get_mut(&selected_url).unwrap() += 1; - - *processed_queue - .lock() - .unwrap() - .get_mut(&selected_url) - .unwrap() += 1; - tree.insert(&text, &selected_url); - - selected_url - } - }; - - let is_stream = serde_json::from_slice::(&body) - .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false)) - .unwrap_or(false); - - let res = match client - .post(format!("{}/{}", worker_url.clone(), route)) - .header( - "Content-Type", - req.headers() - .get("Content-Type") - .and_then(|h| h.to_str().ok()) - .unwrap_or("application/json"), - ) - .body(body.to_vec()) - .send() - .await - { - Ok(res) => res, - Err(_) => return HttpResponse::InternalServerError().finish(), - }; - - let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); - - if !is_stream { - // For non-streaming requests, get response first - let response = match res.bytes().await { - Ok(body) => HttpResponse::build(status).body(body.to_vec()), - Err(e) => { - let error_msg = format!("Failed to get response body: {}", e); - HttpResponse::InternalServerError().body(error_msg) - } - }; - - // Then decrement running queue counter if using CacheAware - if let Router::CacheAware { running_queue, .. } = self { - if let Ok(mut queue) = running_queue.lock() { - if let Some(count) = queue.get_mut(&worker_url) { - *count = count.saturating_sub(1); - } - } - } - - response - } else if let Router::CacheAware { running_queue, .. 
} = self { - let running_queue = Arc::clone(running_queue); - let worker_url = worker_url.clone(); - - HttpResponse::build(status) - .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) - .streaming( - res.bytes_stream() - .map_err(|_| { - actix_web::error::ErrorInternalServerError("Failed to read stream") - }) - .inspect(move |bytes| { - let bytes = bytes.as_ref().unwrap(); - if bytes - .as_ref() - .windows(12) - .any(|window| window == b"data: [DONE]") - { - let mut locked_queue = running_queue.lock().unwrap(); - let count = locked_queue.get_mut(&worker_url).unwrap(); - *count = count.saturating_sub(1); - debug!("streaming is done!!") - } - }), - ) - } else { - HttpResponse::build(status) - .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) - .streaming(res.bytes_stream().map_err(|_| { - actix_web::error::ErrorInternalServerError("Failed to read stream") - })) - } - } - - pub fn add_worker(&self, worker_url: String) { - match self { - Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls } - | Router::CacheAware { worker_urls, .. } => { - let mut urls = worker_urls.write().unwrap(); - info!("Added worker: {}", worker_url); - urls.push(worker_url); - } - } - } -} diff --git a/rust/src/server.rs b/rust/src/server.rs deleted file mode 100644 index 7197b9a2709..00000000000 --- a/rust/src/server.rs +++ /dev/null @@ -1,208 +0,0 @@ -use crate::router::PolicyConfig; -use crate::router::Router; -use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; -use bytes::Bytes; -use env_logger::Builder; -use log::{info, LevelFilter}; -use std::collections::HashMap; -use std::io::Write; - -#[derive(Debug)] -pub struct AppState { - router: Router, - client: reqwest::Client, -} - -impl AppState { - pub fn new( - worker_urls: Vec, - client: reqwest::Client, - policy_config: PolicyConfig, - ) -> Self { - // Create router based on policy - let router = Router::new(worker_urls, policy_config); - - Self { router, client } - } -} - -async fn forward_request( - client: &reqwest::Client, - worker_url: String, - route: String, -) -> HttpResponse { - match client.get(format!("{}{}", worker_url, route)).send().await { - Ok(res) => { - let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); - - // print the status - println!( - "Forwarding Request Worker URL: {}, Route: {}, Status: {}", - worker_url, route, status - ); - match res.bytes().await { - Ok(body) => HttpResponse::build(status).body(body.to_vec()), - Err(_) => HttpResponse::InternalServerError().finish(), - } - } - Err(_) => HttpResponse::InternalServerError().finish(), - } -} - -#[get("/health")] -async fn health(data: web::Data) -> impl Responder { - let worker_url = match data.router.get_first() { - Some(url) => url, - None => return HttpResponse::InternalServerError().finish(), - }; - - forward_request(&data.client, worker_url, "/health".to_string()).await -} - -#[get("/health_generate")] -async fn health_generate(data: web::Data) -> impl Responder { - let worker_url = match data.router.get_first() { - Some(url) => url, - None => return HttpResponse::InternalServerError().finish(), - }; - - forward_request(&data.client, worker_url, "/health_generate".to_string()).await -} - -#[get("/get_server_info")] -async fn get_server_info(data: web::Data) -> impl Responder { - let worker_url = match data.router.get_first() { - Some(url) => url, - None => return 
HttpResponse::InternalServerError().finish(), - }; - - forward_request(&data.client, worker_url, "/get_server_info".to_string()).await -} - -#[get("/v1/models")] -async fn v1_models(data: web::Data) -> impl Responder { - let worker_url = match data.router.get_first() { - Some(url) => url, - None => return HttpResponse::InternalServerError().finish(), - }; - - forward_request(&data.client, worker_url, "/v1/models".to_string()).await -} - -#[get("/get_model_info")] -async fn get_model_info(data: web::Data) -> impl Responder { - let worker_url = match data.router.get_first() { - Some(url) => url, - None => return HttpResponse::InternalServerError().finish(), - }; - - forward_request(&data.client, worker_url, "/get_model_info".to_string()).await -} - -#[post("/generate")] -async fn generate(req: HttpRequest, body: Bytes, data: web::Data) -> impl Responder { - data.router - .dispatch(&data.client, req, body, "generate") - .await -} - -#[post("/v1/chat/completions")] -async fn v1_chat_completions( - req: HttpRequest, - body: Bytes, - data: web::Data, -) -> impl Responder { - data.router - .dispatch(&data.client, req, body, "v1/chat/completions") - .await -} - -#[post("/v1/completions")] -async fn v1_completions( - req: HttpRequest, - body: Bytes, - data: web::Data, -) -> impl Responder { - data.router - .dispatch(&data.client, req, body, "v1/completions") - .await -} - -#[post("/add_worker")] -async fn add_worker( - query: web::Query>, - data: web::Data, -) -> impl Responder { - let worker_url = match query.get("url") { - Some(url) => url.to_string(), - None => { - return HttpResponse::BadRequest() - .body("Worker URL required. Provide 'url' query parameter") - } - }; - data.router.add_worker(worker_url); - HttpResponse::Ok().finish() -} - -pub struct ServerConfig { - pub host: String, - pub port: u16, - pub worker_urls: Vec, - pub policy_config: PolicyConfig, - pub verbose: bool, -} - -pub async fn startup(config: ServerConfig) -> std::io::Result<()> { - Builder::new() - .format(|buf, record| { - use chrono::Local; - writeln!( - buf, - "[Router (Rust)] {} - {} - {}", - Local::now().format("%Y-%m-%d %H:%M:%S"), - record.level(), - record.args() - ) - }) - .filter( - None, - if config.verbose { - LevelFilter::Debug - } else { - LevelFilter::Info - }, - ) - .init(); - - info!("Starting server on {}:{}", config.host, config.port); - info!("Worker URLs: {:?}", config.worker_urls); - info!("Policy Config: {:?}", config.policy_config); - - let client = reqwest::Client::builder() - .build() - .expect("Failed to create HTTP client"); - - let app_state = web::Data::new(AppState::new( - config.worker_urls, - client, - config.policy_config, - )); - - HttpServer::new(move || { - App::new() - .app_data(app_state.clone()) - .service(generate) - .service(v1_chat_completions) - .service(v1_completions) - .service(v1_models) - .service(get_model_info) - .service(health) - .service(health_generate) - .service(get_server_info) - .service(add_worker) - }) - .bind((config.host, config.port))? - .run() - .await -} diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh index 787cc8b952c..1a059d5ff68 100755 --- a/scripts/ci_install_dependency.sh +++ b/scripts/ci_install_dependency.sh @@ -1,16 +1,20 @@ +#!/bin/bash +set -euxo pipefail + # Install the dependency in CI. 
# Use repo from environment variable, passed from GitHub Actions -FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu121/torch2.4}" +FLASHINFER_REPO="${FLASHINFER_REPO:-https://flashinfer.ai/whl/cu124/torch2.4/flashinfer}" SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" bash "${SCRIPT_DIR}/killall_sglang.sh" pip install --upgrade pip -pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/ +pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ -# Force reinstall flashinfer -pip install flashinfer -i ${FLASHINFER_REPO} --force-reinstall +# Force reinstall flashinfer and torch_memory_saver +pip install flashinfer==0.1.6 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps +pip install torch_memory_saver --force-reinstall pip install transformers==4.45.2 sentence_transformers accelerate peft @@ -19,3 +23,6 @@ pip install cutex # For compling xgrammar kernels pip install cuda-python nvidia-cuda-nvrtc-cu12 + +# reinstall sgl-kernel +pip install sgl-kernel --force-reinstall --no-deps diff --git a/scripts/ci_install_rust.sh b/scripts/ci_install_rust.sh index 85b3e95697a..519155dfbe8 100755 --- a/scripts/ci_install_rust.sh +++ b/scripts/ci_install_rust.sh @@ -1,6 +1,14 @@ -# these are required for actix -apt-get update -apt-get install -y libssl-dev pkg-config +#!/bin/bash +set -euxo pipefail + +# Check if sudo is available +if command -v sudo >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y libssl-dev pkg-config +else + apt-get update + apt-get install -y libssl-dev pkg-config +fi # Install rustup (Rust installer and version manager) curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y diff --git a/scripts/deprecated/test_httpserver_classify.py b/scripts/deprecated/test_httpserver_classify.py deleted file mode 100644 index dbcafb88d7d..00000000000 --- a/scripts/deprecated/test_httpserver_classify.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -Usage: -python3 -m sglang.launch_server --disable-cuda-graph --model-path /model/llama-classification - -python3 test_httpserver_classify.py -""" - -import argparse - -import numpy as np -import requests - - -def get_logits(url, prompt): - response = requests.post( - url + "/generate", - json={ - "text": prompt, - "sampling_params": { - "max_new_tokens": 0, - }, - "return_logprob": True, - }, - ) - return response.json()["meta_info"]["normalized_prompt_logprob"] - - -def get_logits_batch(url, prompts): - response = requests.post( - url + "/generate", - json={ - "text": prompts, - "sampling_params": { - "max_new_tokens": 0, - }, - "return_logprob": True, - }, - ) - ret = response.json() - logits = np.array( - list( - ret[i]["meta_info"]["normalized_prompt_logprob"] - for i in range(len(prompts)) - ) - ) - return logits - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="http://127.0.0.1") - parser.add_argument("--port", type=int, default=30000) - args = parser.parse_args() - - url = f"{args.host}:{args.port}" - - # A single request - prompt = "This is a test prompt.<|eot_id|>" - logits = get_logits(url, prompt) - print(f"{logits=}") - - # A batch of requests - prompts = [ - "This is a test prompt.<|eot_id|>", - "This is another test prompt.<|eot_id|>", - "This is a long long long long test prompt.<|eot_id|>", - ] - logits = get_logits_batch(url, prompts) - print(f"{logits=}") diff --git a/scripts/deprecated/test_httpserver_decode_stream.py 
b/scripts/deprecated/test_httpserver_decode_stream.py index 955c368d154..616eaf6c4b1 100644 --- a/scripts/deprecated/test_httpserver_decode_stream.py +++ b/scripts/deprecated/test_httpserver_decode_stream.py @@ -42,7 +42,6 @@ def test_decode_stream(url, return_logprob, top_logprobs_num): if return_logprob: assert data["meta_info"]["input_token_logprobs"] is not None assert data["meta_info"]["output_token_logprobs"] is not None - assert data["meta_info"]["normalized_prompt_logprob"] is not None for logprob, token_id, token_text in data["meta_info"][ "output_token_logprobs" ][prev:]: diff --git a/scripts/deprecated/test_jump_forward.py b/scripts/deprecated/test_jump_forward.py index 60074a04005..315a50b5ba7 100644 --- a/scripts/deprecated/test_jump_forward.py +++ b/scripts/deprecated/test_jump_forward.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, constr import sglang as sgl -from sglang.srt.constrained import build_regex_from_object +from sglang.srt.constrained.outlines_backend import build_regex_from_object from sglang.test.test_utils import ( add_common_sglang_args_and_parse, select_sglang_backend, diff --git a/scripts/killall_sglang.sh b/scripts/killall_sglang.sh index fcad493c59c..163a60f184b 100755 --- a/scripts/killall_sglang.sh +++ b/scripts/killall_sglang.sh @@ -1,5 +1,29 @@ -# Kill all SGLang processes and free the GPU memory. +#!/bin/bash -kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}') -kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') -kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}') +# Check if sudo is available +if command -v sudo >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y lsof +else + apt-get update + apt-get install -y lsof +fi + +# Show current GPU status +nvidia-smi + +# Clean SGLang processes +kill -9 $(ps aux | grep 'sglang::' | grep -v 'grep' | awk '{print $2}') 2>/dev/null +kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') 2>/dev/null +kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}') 2>/dev/null +kill -9 $(ps aux | grep 'sglang.data_parallel' | grep -v 'grep' | awk '{print $2}') 2>/dev/null + +# Clean all GPU processes if any argument is provided +if [ $# -gt 0 ]; then + kill -9 $(nvidia-smi | sed -n '/Processes:/,$p' | grep " [0-9]" | awk '{print $5}') 2>/dev/null + lsof /dev/nvidia* | awk '{print $2}' | xargs kill -9 2>/dev/null +fi + + +# Show GPU status after clean up +nvidia-smi diff --git a/scripts/playground/reference_hf.py b/scripts/playground/reference_hf.py index 7901145c6d7..3ece3d648a9 100644 --- a/scripts/playground/reference_hf.py +++ b/scripts/playground/reference_hf.py @@ -25,12 +25,88 @@ import argparse +import requests import torch -from transformers import AutoModelForCausalLM +from PIL import Image +from transformers import ( + AutoModelForCausalLM, + AutoModelForImageTextToText, + AutoProcessor, +) from sglang.srt.hf_transformers_utils import get_tokenizer +@torch.no_grad() +def vlm_text_with_image(args): + # Load the processor and model for ImageTextToText tasks + processor = AutoProcessor.from_pretrained(args.model_path, trust_remote_code=True) + model = AutoModelForImageTextToText.from_pretrained( + args.model_path, + torch_dtype=args.dtype, + low_cpu_mem_usage=True, + device_map="auto", + trust_remote_code=True, + ) + + torch.cuda.set_device(0) + + # List of image URLs to process + image_urls = [ + 
"https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" + ] + + # Conversation template for the processor + conversation = [ + { + "role": "user", + "content": [ + { + "type": "image", + }, + {"type": "text", "text": "Describe this image."}, + ], + } + ] + + max_new_tokens = args.max_new_tokens + + for i, url in enumerate(image_urls): + # Load the image from the URL + image = Image.open(requests.get(url, stream=True).raw) + + # Apply the chat template to the text prompt + # Notice that not all processors support chat templates. + # LLaVA and QWen are two processors that support chat templates. + if not hasattr(processor, "apply_chat_template"): + raise ValueError("The processor does not support chat templates.") + text_prompt = processor.apply_chat_template( + conversation, add_generation_prompt=True + ) + + # Prepare inputs for the model + inputs = processor(text=[text_prompt], images=[image], return_tensors="pt").to( + "cuda:0" + ) + + # Generate output from the model + output_ids = model.generate( + **inputs, do_sample=False, max_new_tokens=max_new_tokens + ) + output_str = processor.decode(output_ids[0]) + + # Get the logits from the model's forward pass + outputs = model.forward(**inputs) + logits = outputs.logits[0, -1, :] + + print(f"\n========== Image {i} ==========") + print("prefill logits (final)", logits) + # TODO(gaocegege): The output contains numerous <|image_pad|> tokens, + # making it cluttered and difficult to read. + # These tokens should be removed or cleaned up for better readability. + print(output_str) + + @torch.no_grad() def normal_text(args): t = get_tokenizer(args.model_path, trust_remote_code=True) @@ -108,7 +184,11 @@ def synthetic_tokens(args): parser.add_argument("--dtype", type=str, default="float16") + parser.add_argument("--model-type", type=str, default="text") + args = parser.parse_args() - normal_text(args) - # synthetic_tokens(args) + if args.model_type == "vlm": + vlm_text_with_image(args) + else: + normal_text(args) diff --git a/scripts/update_kernel_whl_index.py b/scripts/update_kernel_whl_index.py new file mode 100644 index 00000000000..a42969641f5 --- /dev/null +++ b/scripts/update_kernel_whl_index.py @@ -0,0 +1,16 @@ +# Reference: https://github.com/flashinfer-ai/flashinfer/blob/v0.2.0/scripts/update_whl_index.py + +import hashlib +import pathlib +import re + +for path in sorted(pathlib.Path("sgl-kernel/dist").glob("*.whl")): + with open(path, "rb") as f: + sha256 = hashlib.sha256(f.read()).hexdigest() + ver = re.findall(r"sgl_kernel-([0-9.]+(?:\.post[0-9]+)?)-", path.name)[0] + index_dir = pathlib.Path(f"sgl-whl/cu118/sgl-kernel") + index_dir.mkdir(exist_ok=True) + base_url = "https://github.com/sgl-project/whl/releases/download" + full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}" + with (index_dir / "index.html").open("a") as f: + f.write(f'{path.name}
\n') diff --git a/scripts/version_branch_to_tag.sh b/scripts/version_branch_to_tag.sh index 53272c1efe5..9f587fb0b54 100755 --- a/scripts/version_branch_to_tag.sh +++ b/scripts/version_branch_to_tag.sh @@ -1,4 +1,5 @@ #!/bin/bash +set -euxo pipefail # This script is used for release. # It tags all remote branches starting with 'v' with the same name as the branch, diff --git a/sgl-kernel/.clang-format b/sgl-kernel/.clang-format new file mode 100644 index 00000000000..5e690c02885 --- /dev/null +++ b/sgl-kernel/.clang-format @@ -0,0 +1,8 @@ +BasedOnStyle: Google +IndentWidth: 2 +ColumnLimit: 120 +AllowShortFunctionsOnASingleLine: Empty +DerivePointerAlignment: false +PointerAlignment: Left +NamespaceIndentation: None +SortIncludes: true diff --git a/sgl-kernel/3rdparty/cccl b/sgl-kernel/3rdparty/cccl new file mode 160000 index 00000000000..b5fe509fd11 --- /dev/null +++ b/sgl-kernel/3rdparty/cccl @@ -0,0 +1 @@ +Subproject commit b5fe509fd11a925f90d6495176707cc1184eed9d diff --git a/sgl-kernel/3rdparty/cutlass b/sgl-kernel/3rdparty/cutlass new file mode 160000 index 00000000000..bdd641790ad --- /dev/null +++ b/sgl-kernel/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit bdd641790ad49353b40ada41330552a78d2f8b5a diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer new file mode 160000 index 00000000000..e5a3befbe3e --- /dev/null +++ b/sgl-kernel/3rdparty/flashinfer @@ -0,0 +1 @@ +Subproject commit e5a3befbe3e63025f0158bc96b218a9c5f402ac7 diff --git a/sgl-kernel/3rdparty/turbomind b/sgl-kernel/3rdparty/turbomind new file mode 160000 index 00000000000..0c9d0c724a9 --- /dev/null +++ b/sgl-kernel/3rdparty/turbomind @@ -0,0 +1 @@ +Subproject commit 0c9d0c724a99974ca3af0c12b24ef8a0444c4fd9 diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt deleted file mode 100644 index c635b75c348..00000000000 --- a/sgl-kernel/CMakeLists.txt +++ /dev/null @@ -1,47 +0,0 @@ -cmake_minimum_required(VERSION 3.18) -project(sgl-kernel LANGUAGES CXX CUDA) - -set(CMAKE_EXPORT_COMPILE_COMMANDS ON) - -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) - -set(CMAKE_CUDA_STANDARD 17) -set(CMAKE_CUDA_STANDARD_REQUIRED ON) - -find_package(PythonInterp 3 REQUIRED) -find_package(PythonLibs 3 REQUIRED) - -execute_process( - COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)" - OUTPUT_VARIABLE TORCH_CMAKE_PATH - OUTPUT_STRIP_TRAILING_WHITESPACE -) - -message(STATUS "PYTHON_EXECUTABLE: ${PYTHON_EXECUTABLE}") -message(STATUS "TORCH_CMAKE_PATH: ${TORCH_CMAKE_PATH}") - -list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}") - -find_package(Torch REQUIRED) - -include_directories(${PYTHON_INCLUDE_DIRS}) - -add_library(warp_reduce SHARED - src/sgl-kernel/csrc/warp_reduce.cc - src/sgl-kernel/csrc/warp_reduce_kernel.cu -) - -target_include_directories(warp_reduce PRIVATE - ${CUDA_INCLUDE_DIRS} - ${TORCH_INCLUDE_DIRS} -) - -target_link_libraries(warp_reduce PRIVATE - ${TORCH_LIBRARIES} - ${PYTHON_LIBRARIES} -) - -set_target_properties(warp_reduce PROPERTIES - CUDA_SEPARABLE_COMPILATION ON -) diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile index 3186031acc7..1384f1bcd81 100644 --- a/sgl-kernel/Makefile +++ b/sgl-kernel/Makefile @@ -1,19 +1,28 @@ -.PHONY: tree ln install build clean test +.PHONY: tree ln submodule install build clean rebuild test format tree: - @tree --prune -I "__pycache__|*.egg-info|*.so|build" + @tree --prune -I "__pycache__|*.egg-info|*.so|build|3rdparty|dist" -ln: - @rm -rf build && cmake . 
-DCMAKE_EXPORT_COMPILE_COMMANDS=1 -DCMAKE_CUDA_COMPILER=nvcc -B build && rm -rf compile_commands.json && ln -s build/compile_commands.json compile_commands.json +submodule: + @git submodule update --init --recursive -install: +ln: submodule + @rm -rf build && bear python3 setup.py build + +install: submodule @pip install -e . -build: - @python3 setup.py bdist_wheel +build: submodule + @rm -rf dist/* || true && export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel && pip3 install dist/*whl --force-reinstall --no-deps clean: @rm -rf build dist *.egg-info +rebuild: clean submodule build + @echo "Succeed to rebuild" + test: - @pytest tests/ + @find tests -name "test_*.py" | xargs -n 1 python3 + +format: + @find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md index 857cae366d8..0572f9758ab 100644 --- a/sgl-kernel/README.md +++ b/sgl-kernel/README.md @@ -1,5 +1,19 @@ # SGL Kernel -Kernel Library for SGLang +[Kernel Library](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) for SGLang [![PyPI](https://img.shields.io/pypi/v/sgl-kernel)](https://pypi.org/project/sgl-kernel) + +## Installation + +For CUDA 11.8: + +```bash +pip3 install sgl-kernel -i https://docs.sglang.ai/whl/cu118 +``` + +For CUDA 12.1 or CUDA 12.4: + +```bash +pip3 install sgl-kernel +``` diff --git a/sgl-kernel/THIRDPARTYNOTICES.txt b/sgl-kernel/THIRDPARTYNOTICES.txt new file mode 100644 index 00000000000..fcae14df3aa --- /dev/null +++ b/sgl-kernel/THIRDPARTYNOTICES.txt @@ -0,0 +1,430 @@ +Notice for flashinfer-ai/flashinfer +------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------------------- +Some of the code in this project are adapted from other open-source projects with different +licenses. This product also bundles some third-party components under other open source licenses. +This section summarizes those components and their licenses. +See licenses/ for text of these licenses. 
+ +BSD 3-Clause License +-------------------- + +include/flashinfer/attention/hopper/epilogue.cuh +include/flashinfer/attention/hopper/mainloop.cuh +include/flashinfer/attention/hopper/kernel_traits.cuh +include/flashinfer/attention/hopper/named_barrier.cuh +include/flashinfer/attention/hopper/tile_scheduler.cuh +include/flashinfer/attention/hopper/utils.cuh + +BSD 3-Clause "New" License +-------------------------- + +3rdparty/cutlass +include/flashinfer/attention/hopper/block_sparse_gather.cuh + +Notice for NVIDIA/TensorRT-LLM +------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/sgl-kernel/benchmark/bench_fp8_gemm.py b/sgl-kernel/benchmark/bench_fp8_gemm.py new file mode 100644 index 00000000000..c3f80475356 --- /dev/null +++ b/sgl-kernel/benchmark/bench_fp8_gemm.py @@ -0,0 +1,164 @@ +import argparse +import copy +import itertools + +import torch +import triton +from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm +from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm +from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant + +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + "Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 59136], 1), + ([29568, 8192], 0), + ], + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ + ([2048, 3072], 1), + ([2048, 4096], 1), + ([2048, 2048], 0), + ([2048, 576], 0), + ([2048, 21888], 1), + ([10944, 2048], 0), + ([2048, 2816], 1), + ([1408, 2048], 0), + ], +} + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048], + x_log=False, + line_arg="provider", + line_vals=[ + "vllm-fp8-fp16", + "vllm-fp8-bf16", + "sglang-fp8-fp16", + "sglang-fp8-bf16", + ], + line_names=[ + "vllm-fp8-fp16", + "vllm-fp8-bf16", + "sglang-fp8-fp16", + "sglang-fp8-bf16", + ], + styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")], + ylabel="GB/s", + plot_name="fp8 scaled matmul", + args={}, + 
) +) +def benchmark(batch_size, provider, N, K): + # M, N, K = batch_size, 4096, 8192 + M = batch_size + a = torch.ones((M, K), device="cuda") * 5.0 + b = torch.ones((N, K), device="cuda") * 5.0 + scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) + scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + b_fp8 = b_fp8.t() + quantiles = [0.5, 0.2, 0.8] + + dtype = torch.float16 if "fp16" in provider else torch.bfloat16 + + if "vllm-fp8" in provider: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype), + quantiles=quantiles, + ) + elif "sglang-fp8" in provider: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: sgl_scaled_mm( + a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None + ), + quantiles=quantiles, + ) + + gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3) + return gbps(ms), gbps(max_ms), gbps(min_ms) + + +def prepare_shapes(args): + KN_model_names = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + assert model in WEIGHT_SHAPES + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KN.append(model) + KN_model_names.append(KN) + return KN_model_names + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + help="List of models to benchmark", + ) + parser.add_argument( + "--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes", + ) + args = parser.parse_args() + + KN_model_names = prepare_shapes(args) + for K, N, model_name in KN_model_names: + print(f"{model_name} N={N} K={K}: ") + benchmark.run( + print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K + ) + + print("Benchmark finished!") diff --git a/sgl-kernel/benchmark/bench_int8_gemm.py b/sgl-kernel/benchmark/bench_int8_gemm.py new file mode 100644 index 00000000000..c5a709393c1 --- /dev/null +++ b/sgl-kernel/benchmark/bench_int8_gemm.py @@ -0,0 +1,146 @@ +import argparse +import copy +import itertools + +import torch +import triton +from sgl_kernel import int8_scaled_mm +from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +WEIGHT_SHAPES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + "Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 59136], 1), + ([29568, 8192], 0), + ], + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ + ([2048, 3072], 1), + ([2048, 4096], 1), + ([2048, 2048], 0), + ([2048, 576], 0), + ([2048, 21888], 
1), + ([10944, 2048], 0), + ([2048, 2816], 1), + ([1408, 2048], 0), + ], +} + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048], + x_log=False, + line_arg="provider", + line_vals=["vllm", "sgl-kernel"], + line_names=["vllm int8 gemm", "sgl-kernel int8 gemm"], + styles=[("blue", "-"), ("orange", "-")], + ylabel="GB/s", + plot_name="int8 scaled matmul", + args={}, + ) +) +def benchmark(batch_size, provider, N, K): + M = batch_size + a = to_int8(torch.randn((M, K), device="cuda") * 5) + b = to_int8(torch.randn((N, K), device="cuda").t() * 5) + scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) + scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) + bias = torch.randn((N,), device="cuda", dtype=torch.float16) + + quantiles = [0.5, 0.2, 0.8] + if provider == "sgl-kernel": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias), + quantiles=quantiles, + ) + if provider == "vllm": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias), + quantiles=quantiles, + ) + gbps = ( + lambda ms: ( + (2 * M * N * K - M * N) * a.element_size() + + (3 * M * N) * scale_a.element_size() + ) + * 1e-9 + / (ms * 1e-3) + ) + return gbps(ms), gbps(max_ms), gbps(min_ms) + + +def prepare_shapes(args): + KN_model_names = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + assert model in WEIGHT_SHAPES + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KN.append(model) + KN_model_names.append(KN) + return KN_model_names + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + help="List of models to benchmark", + ) + parser.add_argument( + "--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes", + ) + args = parser.parse_args() + + KN_model_names = prepare_shapes(args) + for K, N, model_name in KN_model_names: + print(f"{model_name} N={N} K={K}: ") + benchmark.run( + print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K + ) + + print("Benchmark finished!") diff --git a/sgl-kernel/benchmark/bench_lightning_attention_decode.py b/sgl-kernel/benchmark/bench_lightning_attention_decode.py new file mode 100644 index 00000000000..24872e61a4d --- /dev/null +++ b/sgl-kernel/benchmark/bench_lightning_attention_decode.py @@ -0,0 +1,299 @@ +import itertools +import math + +import torch +import triton +import triton.language as tl +from sgl_kernel import lightning_attention_decode + + +def next_power_of_2(n): + return 2 ** (int(math.ceil(math.log(n, 2)))) + + +@triton.jit +def _decode_kernel( + Q, + K, + V, + KV, + Out, + S, + b: tl.constexpr, + h: tl.constexpr, + n: tl.constexpr, + d: tl.constexpr, + d_original: tl.constexpr, + e: tl.constexpr, + e_original: tl.constexpr, +): + off_bh = tl.program_id(0) + off_h = off_bh % h + + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + kv_offset = off_bh * d * e + + s = tl.load(S + off_h) + ratio = tl.exp(-s) + + d_idx = tl.arange(0, d) + e_idx = tl.arange(0, e) + + # Create masks for original dimensions + d_mask = d_idx < d_original + e_mask = e_idx < e_original + + # Load with masking + q = tl.load(Q + qk_offset + d_idx, 
mask=d_mask, other=0.0) + k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0) + v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0) + + # Load KV with 2D masking + kv = tl.load( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + mask=(d_mask[:, None] & e_mask[None, :]), + other=0.0, + ) + + # Compute outer product using element-wise operations + k_v_prod = k[:, None] * v[None, :] + kv = ratio * kv + k_v_prod + + # Store KV with 2D masking + tl.store( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + kv.to(KV.dtype.element_ty), + mask=(d_mask[:, None] & e_mask[None, :]), + ) + + # Compute matrix-vector multiplication using element-wise operations and reduction + o = tl.sum(q[:, None] * kv, axis=0) + + # Store output with masking + tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask) + + +def triton_lightning_attn_decode(q, k, v, kv, s): + """Triton implementation of Lightning Attention decode operation""" + b, h, n, d = q.shape + e = v.shape[-1] + assert n == 1, "Sequence length must be 1 in decode mode" + + # Get padded dimensions (power of 2) + d_padded = next_power_of_2(d) + e_padded = next_power_of_2(e) + + # Create output tensor (padded) + o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + + # Create padded tensors without actually padding the data + q_padded = torch.empty(b, h, n, d_padded, dtype=q.dtype, device=q.device) + k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device) + v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + kv_padded = torch.empty( + b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device + ) + + # Copy data to padded tensors + q_padded[..., :d] = q + k_padded[..., :d] = k + v_padded[..., :e] = v + kv_padded[..., :d, :e] = kv + + # Launch kernel + grid = (b * h, 1) + _decode_kernel[grid]( + q_padded, + k_padded, + v_padded, + kv_padded, + o_padded, + s, + b=b, + h=h, + n=n, + d=d_padded, + d_original=d, + e=e_padded, + e_original=e, + ) + + # Get unpadded outputs + o = o_padded[..., :e] + kv_out = kv_padded[..., :d, :e] + + return o, kv_out + + +def lightning_attention_decode_naive(q, k, v, past_kv, slope): + """Naive implementation of lightning attention decode""" + original_dtype = q.dtype + ratio = torch.exp(-slope) # [h, 1, 1] + + kv = past_kv + b, h, n, d = q.shape + + output = [] + for i in range(n): + kv = ratio * kv.to(torch.float32) + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + qkv = torch.einsum( + "... n e, ... e d -> ... 
n d", + q[:, :, i : i + 1].to(torch.float32), + kv.to(torch.float32), + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + + return output.to(original_dtype), kv + + +def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv): + return lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) + + +def calculate_diff(batch_size): + dtype = torch.bfloat16 + device = torch.device("cuda") + num_heads = 64 + head_dim = 96 + seq_len = 1 + + q = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + k = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + output_naive, new_kv_naive = lightning_attention_decode_naive( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ) + + output_kernel = torch.empty_like(output_naive) + new_kv_kernel = torch.empty_like(new_kv_naive) + lightning_attention_decode_kernel( + q.clone(), + k.clone(), + v.clone(), + past_kv.clone(), + slope.clone(), + output_kernel, + new_kv_kernel, + ) + + output_triton, new_kv_triton = triton_lightning_attn_decode( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ) + + if ( + torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2) + and torch.allclose(output_naive, output_triton, atol=1e-2, rtol=1e-2) + and torch.allclose(new_kv_naive, new_kv_kernel, atol=1e-2, rtol=1e-2) + and torch.allclose(new_kv_naive, new_kv_triton, atol=1e-2, rtol=1e-2) + ): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [i for i in range(1, 65)] # 1 to 128 +configs = [(bs,) for bs in batch_size_range] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["naive", "kernel", "triton"], + line_names=["PyTorch Naive", "SGL Kernel", "Triton"], + styles=[("blue", "-"), ("red", "-"), ("green", "-")], + ylabel="us", + plot_name="lightning-attention-decode-performance", + args={}, + ) +) +def benchmark(batch_size, provider): + dtype = torch.bfloat16 + device = torch.device("cuda") + num_heads = 64 + head_dim = 96 + seq_len = 1 + + q = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + k = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "naive": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: lightning_attention_decode_naive( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ), + quantiles=quantiles, + ) + elif provider == "kernel": + output = torch.empty( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: lightning_attention_decode_kernel( + q.clone(), + k.clone(), + v.clone(), + past_kv.clone(), + slope.clone(), + output, + new_kv, + ), + quantiles=quantiles, + ) + elif 
provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: triton_lightning_attn_decode( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./configs/benchmark_ops/lightning_attention_decode_sgl/", + help="Path to save lightning attention decode benchmark results", + ) + args = parser.parse_args() + + # Run correctness test + calculate_diff(batch_size=4) + + # Run performance benchmark + benchmark.run(print_data=True) diff --git a/sgl-kernel/build.sh b/sgl-kernel/build.sh index b276f0141c2..ffa798d145a 100755 --- a/sgl-kernel/build.sh +++ b/sgl-kernel/build.sh @@ -1,13 +1,28 @@ #!/bin/bash - set -ex +PYTHON_VERSION=$1 +CUDA_VERSION=$2 +PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.} + +if (( ${CUDA_VERSION%.*} < 12 )); then + ENABLE_SM90A=0 +else + ENABLE_SM90A=1 +fi -docker run --rm -it \ +docker run --rm \ -v "$(pwd)":/sgl-kernel \ - pytorch/manylinux-builder:cuda12.1 \ + pytorch/manylinux-builder:cuda${CUDA_VERSION} \ bash -c " - pip install --no-cache-dir torch==2.4.0 --index-url https://download.pytorch.org/whl/cu121 && \ + ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \ + ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja && \ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \ + export CUDA_VERSION=${CUDA_VERSION} && \ + export SGL_KERNEL_ENABLE_BF16=1 && \ + export SGL_KERNEL_ENABLE_FP8=1 && \ + export SGL_KERNEL_ENABLE_SM90A=${ENABLE_SM90A} && \ + mkdir -p /usr/lib/x86_64-linux-gnu/ && \ + ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \ cd /sgl-kernel && \ - python setup.py bdist_wheel + ${PYTHON_ROOT_PATH}/bin/python setup.py bdist_wheel " diff --git a/sgl-kernel/developer_guide.md b/sgl-kernel/developer_guide.md new file mode 100644 index 00000000000..2b9859d948f --- /dev/null +++ b/sgl-kernel/developer_guide.md @@ -0,0 +1,55 @@ +# Developer Guide for sgl-kernel + +## Development Environment Setup + +Use Docker to set up the development environment. See [Docker setup guide](https://github.com/sgl-project/sglang/blob/main/docs/developer/development_guide_using_docker.md#setup-docker-container). + +Create and enter development container: +```bash +docker run -itd --shm-size 32g --gpus all -v $HOME/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh +docker exec -it sglang_zhyncs /bin/zsh +``` + +## Project Structure + +### Dependencies + +Third-party libraries: + +- [CCCL](https://github.com/NVIDIA/cccl) +- [CUTLASS](https://github.com/NVIDIA/cutlass) +- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) +- [TurboMind](https://github.com/InternLM/turbomind) + +### Kernel Development + +Steps to add a new kernel: + +1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc) +2. Expose interface in [src/sgl-kernel/include/sgl_kernels_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h) +3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc) +4. 
Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py) +5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py) +6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source + +### Build & Install + +Development build: + +```bash +make build +``` + +Note: + +The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, try using `make rebuild`. + +### Testing & Benchmarking + +1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests) +2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark) +3. Run test suite + +### Release new version + +Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/version.py) diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 4330d2e19be..bb7d6943348 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -4,23 +4,20 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.2" +version = "0.0.3.post1" description = "Kernel Library for SGLang" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = { file = "LICENSE" } classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", - "Programming Language :: C++", - "Programming Language :: CUDA", -] -dependencies = [ - "torch", + "Environment :: GPU :: NVIDIA CUDA" ] +dependencies = [] [project.urls] -"Homepage" = "https://github.com/sgl-project/sglang" +"Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel" "Bug Tracker" = "https://github.com/sgl-project/sglang/issues" [tool.setuptools] diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index f2af83643e5..90c3cbc1d3c 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -1,32 +1,159 @@ +import multiprocessing +import os +import sys +from pathlib import Path + +import torch from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension +root = Path(__file__).parent.resolve() + + +if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv: + sys.argv.extend(["--plat-name", "manylinux2014_x86_64"]) + + +def _get_cuda_version(): + if torch.version.cuda: + return tuple(map(int, torch.version.cuda.split("."))) + return (0, 0) + + +def _get_device_sm(): + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + return 0 + + +def _get_version(): + with open(root / "pyproject.toml") as f: + for line in f: + if line.startswith("version"): + return line.split("=")[1].strip().strip('"') + + +operator_namespace = "sgl_kernels" +cutlass_default = root / "3rdparty" / "cutlass" +cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) +flashinfer = root / "3rdparty" / "flashinfer" +turbomind = root / "3rdparty" / "turbomind" +include_dirs = [ + cutlass.resolve() / "include", + cutlass.resolve() / "tools" / "util" / "include", + root / "src" / "sgl-kernel" / "include", + root / "src" / 
"sgl-kernel" / "csrc", + flashinfer.resolve() / "include", + flashinfer.resolve() / "include" / "gemm", + flashinfer.resolve() / "csrc", + "cublas", + "cublasLt", + turbomind.resolve(), + turbomind.resolve() / "src", +] + +nvcc_flags = [ + "-DNDEBUG", + f"-DOPERATOR_NAMESPACE={operator_namespace}", + "-O3", + "-Xcompiler", + "-fPIC", + "-gencode=arch=compute_75,code=sm_75", + "-gencode=arch=compute_80,code=sm_80", + "-gencode=arch=compute_89,code=sm_89", + "-gencode=arch=compute_90,code=sm_90", + "-std=c++17", + "-use_fast_math", + "-DFLASHINFER_ENABLE_F16", + "-Xcompiler=-Wconversion", + "-Xcompiler=-fno-strict-aliasing", +] +nvcc_flags_fp8 = [ + "-DFLASHINFER_ENABLE_FP8", + "-DFLASHINFER_ENABLE_FP8_E4M3", + "-DFLASHINFER_ENABLE_FP8_E5M2", +] + +sources = [ + "src/sgl-kernel/torch_extension.cc", + "src/sgl-kernel/csrc/trt_reduce_internal.cu", + "src/sgl-kernel/csrc/trt_reduce_kernel.cu", + "src/sgl-kernel/csrc/moe_align_kernel.cu", + "src/sgl-kernel/csrc/int8_gemm_kernel.cu", + "src/sgl-kernel/csrc/fp8_gemm_kernel.cu", + "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", + "src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu", + "3rdparty/flashinfer/csrc/activation.cu", + "3rdparty/flashinfer/csrc/bmm_fp8.cu", + "3rdparty/flashinfer/csrc/norm.cu", + "3rdparty/flashinfer/csrc/sampling.cu", + "3rdparty/flashinfer/csrc/renorm.cu", + "3rdparty/flashinfer/csrc/rope.cu", +] + +enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1" +enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1" +enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1" +cuda_version = _get_cuda_version() +sm_version = _get_device_sm() + +if torch.cuda.is_available(): + if cuda_version >= (12, 0) and sm_version >= 90: + nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") + if sm_version >= 90: + nvcc_flags.extend(nvcc_flags_fp8) + if sm_version >= 80: + nvcc_flags.append("-DFLASHINFER_ENABLE_BF16") +else: + # compilation environment without GPU + if enable_sm90a: + nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") + if enable_fp8: + nvcc_flags.extend(nvcc_flags_fp8) + if enable_bf16: + nvcc_flags.append("-DFLASHINFER_ENABLE_BF16") + +for flag in [ + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", +]: + try: + torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag) + except ValueError: + pass + +cxx_flags = ["-O3"] +libraries = ["c10", "torch", "torch_python", "cuda", "cublas", "cublasLt"] +extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"] + +ext_modules = [ + CUDAExtension( + name="sgl_kernel.ops._kernels", + sources=sources, + include_dirs=include_dirs, + extra_compile_args={ + "nvcc": nvcc_flags, + "cxx": cxx_flags, + }, + libraries=libraries, + extra_link_args=extra_link_args, + py_limited_api=True, + ), +] + setup( name="sgl-kernel", - version="0.0.2", - packages=find_packages(where="src"), + version=_get_version(), + packages=find_packages(), package_dir={"": "src"}, - ext_modules=[ - CUDAExtension( - "sgl_kernel.ops.warp_reduce_cuda", - [ - "src/sgl-kernel/csrc/warp_reduce.cc", - "src/sgl-kernel/csrc/warp_reduce_kernel.cu", - ], - extra_compile_args={ - "nvcc": [ - "-O3", - "-Xcompiler", - "-fPIC", - "-gencode=arch=compute_75,code=sm_75", - "-gencode=arch=compute_80,code=sm_80", - "-gencode=arch=compute_89,code=sm_89", - "-gencode=arch=compute_90,code=sm_90", - ], - "cxx": ["-O3"], - }, + ext_modules=ext_modules, + cmdclass={ + "build_ext": 
BuildExtension.with_options( + use_ninja=True, max_jobs=multiprocessing.cpu_count() ) - ], - cmdclass={"build_ext": BuildExtension}, - install_requires=["torch"], + }, + options={"bdist_wheel": {"py_limited_api": "cp39"}}, ) diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index edf3921db79..a3d35072d03 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -1,3 +1,51 @@ -from .ops import warp_reduce +from sgl_kernel.ops import ( + apply_rope_with_cos_sin_cache_inplace, + bmm_fp8, + custom_dispose, + custom_reduce, + fp8_scaled_mm, + fused_add_rmsnorm, + gelu_and_mul, + gelu_tanh_and_mul, + gemma_fused_add_rmsnorm, + gemma_rmsnorm, + get_graph_buffer_ipc_meta, + init_custom_reduce, + int8_scaled_mm, + lightning_attention_decode, + min_p_sampling_from_probs, + moe_align_block_size, + register_graph_buffers, + rmsnorm, + sampling_scaling_penalties, + silu_and_mul, + top_k_renorm_prob, + top_k_top_p_sampling_from_probs, + top_p_renorm_prob, +) -__all__ = ["warp_reduce"] +__all__ = [ + "apply_rope_with_cos_sin_cache_inplace", + "bmm_fp8", + "custom_dispose", + "custom_reduce", + "fp8_scaled_mm", + "fused_add_rmsnorm", + "gelu_and_mul", + "gelu_tanh_and_mul", + "gemma_fused_add_rmsnorm", + "gemma_rmsnorm", + "get_graph_buffer_ipc_meta", + "init_custom_reduce", + "int8_scaled_mm", + "lightning_attention_decode", + "min_p_sampling_from_probs", + "moe_align_block_size", + "register_graph_buffers", + "rmsnorm", + "sampling_scaling_penalties", + "silu_and_mul", + "top_k_renorm_prob", + "top_k_top_p_sampling_from_probs", + "top_p_renorm_prob", +] diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h new file mode 100644 index 00000000000..c83cf49ad83 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h @@ -0,0 +1,275 @@ +// Adapted from +// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h + +#pragma once + +#include +#include + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +class EpilogueVisitorPerRowPerCol { + public: + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using ScaleTileIterator = ScaleTileIterator_; + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + using ElementOutput = typename OutputTileIterator::Element; + using LayoutOutput = cutlass::layout::RowMajor; + using ElementAccumulator = ElementAccumulator_; + + using AlphaScaleElementType = typename ScaleTileIterator::Element; + + using ElementCompute = ElementCompute_; + using AccumulatorFragment = Array; + using ComputeFragment = Array; + using OutputVector = Array; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); + + /// Argument structure + struct Arguments { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t 
batch_stride_D; + + // + // Methods + // + Arguments() : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {} + + Arguments(typename ElementwiseFunctor::Params elementwise_) + : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {} + + Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, int64_t batch_stride_C_, + int64_t batch_stride_D_) + : elementwise(elementwise_), + batch_stride_alpha(batch_stride_alpha_), + batch_stride_C(batch_stride_C_), + batch_stride_D(batch_stride_D_) {} + }; + + struct Params { + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args) + : elementwise(args.elementwise), + batch_stride_alpha(args.batch_stride_alpha), + batch_stride_C(args.batch_stride_C), + batch_stride_D(args.batch_stride_D) {} + }; + + /// Shared storage + struct SharedStorage {}; + + private: + Params const& params_; + SharedStorage& shared_storage_; + MatrixCoord extent_; + MatrixCoord extent_real_; + ElementwiseFunctor elementwise_; + + bool const with_bias_; + bool const per_token_quant_; + bool const per_channel_quant_; + + AlphaScaleElementType* ptr_alpha_row_; + AlphaScaleElementType* ptr_alpha_col_; + ScaleTileIterator iterator_alpha_col_; + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + + AlphaScaleElementType element_alpha_row_ = 1.0f; + AlphaScaleElementType element_alpha_col_ = 1.0f; + typename ScaleTileIterator::Fragment fragment_alpha_col_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator beta_; + + int column_offset_; + + MatrixCoord thread_offset_; + + public: + CUTLASS_DEVICE + EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, + cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, + typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, bool with_bias, bool per_token_quant, + bool per_channel_quant, AlphaScaleElementType* ptr_alpha_row, + AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C, + typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), + int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) + : params_(params), + shared_storage_(shared_storage), + extent_(problem_size), + elementwise_(params.elementwise), + with_bias_(with_bias), + per_token_quant_(per_token_quant), + per_channel_quant_(per_channel_quant), + ptr_alpha_row_(ptr_alpha_row), + ptr_alpha_col_(ptr_alpha_col), + iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset), + iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset), + iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset), + extent_real_(problem_size_real) { + if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) { + element_alpha_col_ = *ptr_alpha_col_; + } + + if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) { + element_alpha_row_ = *ptr_alpha_row_; + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition(int split_k_index, ///< Index of this threadblock 
within split-K partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); + iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); + iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() { + if (per_channel_quant_) { + iterator_alpha_col_.load(fragment_alpha_col_); + } + + if (with_bias_) { + iterator_C_.load(fragment_C_); + } + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_D_.clear(); + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + // load alpha_row in begin_step only when per token(row) scaling is used + if (per_token_quant_) { + int thread_offset_row = + iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); + + arch::global_load( + element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); + } + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) { + NumericArrayConverter source_converter; + + ComputeFragment result = source_converter(accum); + if (per_channel_quant_) { + ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[column_idx]; + result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); + } else { + result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); + } + + if (with_bias_) { + NumericArrayConverter bias_converter; + OutputVector bias = reinterpret_cast(&fragment_C_)[column_idx]; + result = bias_accumulator_(result, bias_converter(bias)); + } + + // Convert to the output + NumericArrayConverter output_converter; + OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the end of a row + CUTLASS_DEVICE + void end_row(int row_idx) {} + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() {} + + private: + CUTLASS_DEVICE + ComputeFragment per_token_channel_scale_accumulator_(ComputeFragment const& accum, ComputeFragment const& scale_col, + AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col[i] * scale_row); + } + + return result; + } + + CUTLASS_DEVICE + ComputeFragment per_token_scale_accumulator_(ComputeFragment const& accum, AlphaScaleElementType const& scale_col, + AlphaScaleElementType const& scale_row) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) { + result[i] = accum[i] * (scale_col * scale_row); + } + + return result; + } + + CUTLASS_DEVICE + ComputeFragment bias_accumulator_(ComputeFragment const& accum, ComputeFragment const& bias) { + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 
OutputVector::kElements; ++i) { + result[i] = accum[i] + bias[i]; + } + return result; + } +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h new file mode 100644 index 00000000000..33e82decc2b --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h @@ -0,0 +1,339 @@ +// Adapted from +// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h +#pragma once + +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/* + This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) + It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs + and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs. + + Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support + that feature at the moment. + */ + +template +class GemmUniversalBaseCompat { + public: + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + + protected: + /// Kernel parameters object + typename GemmKernel::Params params_; + + protected: + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + gemm_k_size = args.problem_size.k(); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = + const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + } + + public: + /// Constructs the GEMM. 
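+  /// The default constructor performs no work; typical host-side usage is
+  /// can_implement(args) -> get_workspace_size(args) -> initialize(args,
+  /// workspace, stream) -> run(stream), or the operator()(args, ...) overload,
+  /// which combines the last two steps.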
+ GemmUniversalBaseCompat() {} + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const& args) { + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + ThreadblockSwizzle threadblock_swizzle; + dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + + if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) { + return Status::kErrorInvalidProblem; + } + + return GemmKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); + + size_t workspace_bytes = 0; + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) { + // Split-K parallel always requires a temporary workspace + workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); + } else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) { + // Serial split-K only requires a temporary workspace if the number of partitions along the + // GEMM K dimension is greater than one. + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); + + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result << "}"); + + return result; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); + + int max_active_blocks = -1; + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + if (smem_size <= (48 << 10)) { + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel, + GemmKernel::kThreadCount, smem_size); + + if (result == cudaSuccess) { + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + } else { + // Query assuming zero shared memory then compute occupancy limit based on SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, Kernel, + GemmKernel::kThreadCount, 0); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " + << cudaGetErrorString(result)); + + return -1; + } + + if (smem_capacity < 0) { + int device_idx = 0; + result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + return -1; + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result 
!= cudaSuccess) { + return -1; + } + + smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); + } + + int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); + + CUTLASS_TRACE_HOST(" occupancy: " << occupancy); + + return occupancy; + } + + CUTLASS_TRACE_HOST(" returning internal error"); + + return -1; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + size_t workspace_bytes = get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + if (workspace_bytes) { + if (!workspace) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + + return Status::kErrorWorkspaceNull; + } + + if (args.mode == GemmUniversalMode::kGemm) { + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + + return Status::kErrorInternal; + } + } + } + + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + // Initialize the Params structure + params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = + cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); + + // + // Configure grid and block dimensions + // + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + // + // Launch kernel + // + + CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); + + // Launch + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. 
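+  /// This overload first calls initialize(args, workspace, stream) and, on
+  /// success, runs the kernel.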
+ Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h new file mode 100644 index 00000000000..674e191a077 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h @@ -0,0 +1,453 @@ +// Adapted from +// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h + +#pragma once + +#include +#include +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithEpilogueVisitor { + public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementCompute = typename EpilogueVisitor::ElementCompute; + using LayoutAlphaCol = cutlass::layout::RowMajor; + using LayoutAlphaRow = cutlass::layout::ColumnMajor; + using TensorRefAlphaCol = TensorRef; + using TensorRefAlphaRow = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + using EpilogueOutputOp = + typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments { + // + // Data members + // + + 
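+    // Host-side tensor references for A/B, the per-column and per-row scale
+    // vectors (alpha_col/alpha_row), and C/D (C doubles as the optional bias);
+    // the Params structure below precomputes device-side iterator parameters
+    // from these references.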
GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefAlphaCol ref_alpha_col; + TensorRefAlphaRow ref_alpha_row; + TensorRefC ref_C; + TensorRefC ref_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_D; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {} + + /// constructs an arguments structure + Arguments(GemmCoord problem_size_, TensorRefA ref_A_, TensorRefB ref_B_, TensorRefAlphaCol ref_alpha_col_, + TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, + typename EpilogueVisitor::Arguments epilogue_visitor_) + : mode(GemmUniversalMode::kGemm), + problem_size(problem_size_), + batch_count(1), + ref_A(ref_A_), + ref_B(ref_B_), + ref_alpha_col(ref_alpha_col_), + ref_alpha_row(ref_alpha_row_), + ref_C(ref_C_), + ref_D(ref_D_), + batch_stride_A(0), + batch_stride_B(0), + batch_stride_D(0), + epilogue_visitor(epilogue_visitor_) {} + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void* ptr_A; + void* ptr_B; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0), + params_A(0), + params_B(0), + params_alpha_col(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_alpha_col(nullptr), + ptr_alpha_row(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + batch_stride_A(0), + batch_stride_B(0) {} + + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_) + : problem_size(args.problem_size), + swizzle_log_tile(0), + params_A(args.ref_A.layout()), + params_B(args.ref_B.layout()), + params_alpha_col(args.ref_alpha_col.layout()), + params_alpha_row(args.ref_alpha_col.layout()), + params_C(args.ref_C.layout()), + params_D(args.ref_D.layout()), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(args.problem_size.k()), + ptr_A(args.ref_A.data()), + ptr_B(args.ref_B.data()), + ptr_alpha_col(args.ref_alpha_col.data()), + ptr_alpha_row(args.ref_alpha_row.data()), + ptr_C(args.ref_C.data()), + ptr_D(args.ref_D.data()), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + epilogue_visitor(args.epilogue_visitor) { + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + if (args.mode == 
GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + int const kAlignK = + const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + + struct { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + + public: + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) { + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value || + platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) { + return 0; + } + +#define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) { + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + return; + } + + int offset_k = 0; + int problem_size_k = 
params.problem_size.k(); + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + +#if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) { + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } +#endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B(params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + bool with_bias = true; + if (params.ptr_C == nullptr) { + with_bias = false; + } + + EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, params.problem_size.mn(), + thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C, + params.params_D, with_bias, true, true, params.ptr_alpha_row, params.ptr_alpha_col, + params.ptr_C, params.ptr_D, threadblock_offset, + blockIdx.y * params.problem_size.m()); + + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { + 
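+      // In batched/array mode the threadblock's k tile index selects the batch;
+      // set_batch_index() offsets the alpha/C/D iterators by the per-batch strides.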
epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) { + if constexpr (platform::is_same::value) { + run_kernel_(params, shared_storage); + } else { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) { + run_kernel(params, shared_storage); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu new file mode 100644 index 00000000000..3e33e143c0c --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu @@ -0,0 +1,624 @@ +// Adapted from +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm90.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" + +using namespace cute; + +#if defined CUDA_VERSION && CUDA_VERSION >= 12040 +template typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT, + typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>> +struct DeviceGemmFp8RowwiseSm89 { + static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); + + using ElementA = ElementType; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementType; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementC = OutElementType; + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementOutput = OutElementType; + using LayoutOutput = cutlass::layout::RowMajor; + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = AccumElementType; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm89; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + // Number of epilogue stages in EVT + static constexpr int EVTEpilogueStages = 1; + + using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout; + + // Definition of EVT + using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch; + + using ComputeBScale = 
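+  // Epilogue visitor tree (EVT): D = scale_a * (scale_b * accum) (+ bias when
+  // present); scale_b is row-broadcast (one value per output column) and
+  // scale_a is column-broadcast (one value per output row).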
cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>; + using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>; + using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeAScale = + cutlass::epilogue::threadblock::VisitorCompute; + using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast>; + using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT; + + // With bias + using biasSrc = + cutlass::epilogue::threadblock::VisitorRowBroadcast>; + using ComputeAScaleWithBias = + cutlass::epilogue::threadblock::VisitorCompute; + using EpilogueAScaleWithBias = + cutlass::epilogue::threadblock::Sm80EVT; + + using dTar = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementC, cutlass::FloatRoundStyle::round_to_nearest, Stride>; + using EpilogueStore = + typename cutlass::platform::conditional, + cutlass::epilogue::threadblock::Sm80EVT>::type; + + using EpilogueOp = EpilogueStore; + + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB, + cutlass::ComplexTransform::kNone, AlignmentB, ElementC, LayoutC, AlignmentC, ElementAccumulator, + ElementComputeEpilogue, OperatorClass, ArchTag, CtaShape, WarpShape, InstructionShape, EpilogueOp, + ThreadblockSwizzle, Stages, FP8MathOperator, EVTEpilogueStages>::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementT = typename Gemm::ElementA; + using ElementOutput = typename Gemm::ElementD; + using ElementComputeEpilogue = float; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + ElementT const* ptr_a = reinterpret_cast(a.data_ptr()); + ElementT const* ptr_b = reinterpret_cast(b.data_ptr()); + ElementOutput const* ptr_bias = nullptr; + if constexpr (WithBias) { + TORCH_CHECK(bias.has_value()) + ptr_bias = reinterpret_cast(bias.value().data_ptr()); + } + ElementOutput* ptr_d = reinterpret_cast(out.data_ptr()); + ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast(scales_a.data_ptr()); + ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast(scales_b.data_ptr()); + + typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, // Mode + {m, n, k}, // Problem size + 1, // Split-k factor + {}, // Epilogue args + ptr_a, // a pointer + ptr_b, // b pointer + nullptr, // c pointer (unused) + nullptr, // d pointer (unused) + m * k, // batch stride a (unused) + n * k, // batch stride b (unused) + m * n, // batch stride c (unused) + m * n, // batch stride d (unused) + lda, // stride a + ldb, // stride b + ldc, // stride c (unused) + ldc); // stride d (unused) + if constexpr (WithBias) { + args.epilogue = {{ + { + {}, // Accumulator + {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, + {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_d, {n, _1{}, _0{}}}}; + } else { + args.epilogue = {{ + { + {}, // Accumulator + {ptr_scales_b, 
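// Note on the nesting here: each Sm80EVT expects the arguments of its child
// nodes first and the node's own compute/store arguments last, so the braces
// mirror the tree declared in DeviceGemmFp8RowwiseSm89 -- innermost
// {accumulator, scales_b row-broadcast, multiply}, then {..., scales_a
// col-broadcast, bias row-broadcast, multiply}, and finally the aux store
// {ptr_d, stride}. Each broadcast argument is, roughly, {pointer, value used
// when the pointer is null, stride}.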
ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, + {} // Multiplies + }, + {ptr_d, {n, _1{}, _0{}}}}; + } + + return args; +} + +template +void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + auto args = prepare_sm89_fp8_args(out, a, b, scales_a, scales_b, bias); + Gemm gemm_op; + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + TORCH_CHECK(can_implement == cutlass::Status::kSuccess) + + auto status = gemm_op(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess) +} + +template +void sm89_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementInput = cutlass::float_e4m3_t; + using ElementOutput = OutType; + using AccumElementType = float; + if (bias) { + using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } else { + using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } +} + +template +void sm89_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + uint32_t const m = a.size(0); + uint32_t const n = out.size(1); + + if (m == 1) { + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 16) { + // M in (1, 16] + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } else if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 64) { + // M in (16, 64] + if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 128) { + // M in (64, 128] + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } else if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 256) { + // M in (128, 256] + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, 
scales_b, bias); + } else if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 512) { + // M in (256, 512) + if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } + } else { + // M in (512, inf) + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); + } + } +} +#endif + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 +template +struct DeviceGemmFp8RowwiseSm90 { + static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); + + // A matrix configuration + using ElementA = ElementType; // Element type for A matrix operand + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using ElementB = ElementType; // Element type for B matrix operand + using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using ElementC = void; // Element type for C matrix operands + using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrices in + // units of elements (up to 16 bytes) + + // Output matrix configuration + using ElementOutput = OutElementType; // Element type for output matrix operands + using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + // // Auxiliary matrix configuration and other fusion types + // using ElementBias = float; + + // Multiply-accumulate blocking/pipelining details + using ElementAccumulator = AccumElementType; // Element type for internal accumulation + using ElementCompute = float; // Element type for compute + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = CTAShape; // Threadblock-level tile size + + static constexpr bool PONG = false; + static constexpr bool FAST_ACCUM = true; + static constexpr bool USE_BIAS = false; + + using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized + // based on the tile size + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default + // setting in the Collective Builder + // Implement rowwise scaling epilogue. 
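// Per output element (m, n) the EVT assembled below evaluates, roughly,
//   D[m][n] = scales_a[m] * (scales_b[n] * acc[m][n])   (+ bias[n] when WithBias)
// A minimal host-side reference for sanity checks (names are hypothetical):
//   for (int i = 0; i < M; ++i)
//     for (int j = 0; j < N; ++j) {
//       float v = scales_a[i] * (scales_b[j] * acc[i][j]);
//       ref[i][j] = static_cast<OutT>(with_bias ? v + static_cast<float>(bias[j]) : v);
//     }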
+ using XScale = + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, + cute::Stride, cute::Int<0>, cute::Int<0>>>; + + using WScale = + cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT; + + // With bias + using ComputeWithBias = + cutlass::epilogue::fusion::Sm90Compute; + using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = typename cutlass::platform::conditional::type; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementComputeEpilogue, ElementC, LayoutC, + AlignmentC, ElementOutput, LayoutOutput, AlignmentOutput, cutlass::epilogue::TmaWarpSpecialized, + EpilogueEVT>::CollectiveOp; + + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; + using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + + using SlowAccum = DefaultSchedule; + using FastAccum = FastPongSchedule; // Default apply Pingpong + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduleType>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementT = typename Gemm::ElementA; + using ElementOutput = typename Gemm::ElementD; + using ElementComputeEpilogue = float; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + ElementT const* ptr_a = reinterpret_cast(a.data_ptr()); + ElementT const* ptr_b = reinterpret_cast(b.data_ptr()); + ElementOutput const* ptr_bias = nullptr; + if constexpr (WithBias) { + TORCH_CHECK(bias.has_value()) + ptr_bias = reinterpret_cast(bias.value().data_ptr()); + } + ElementOutput* ptr_d = reinterpret_cast(out.data_ptr()); + ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast(scales_a.data_ptr()); + ElementComputeEpilogue const* ptr_scales_b = 
reinterpret_cast(scales_b.data_ptr()); + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1)); + StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1)); + StrideC stride_c; + StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1)); + typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {ptr_a, stride_a, ptr_b, stride_b}, + {{}, // epilogue.thread + nullptr, + stride_c, + ptr_d, + stride_d}}; + if constexpr (WithBias) { + args.epilogue.thread = { + {ptr_scales_a}, + { + {ptr_scales_b}, + {}, // Accumulator + {} // Multiplies + }, + {ptr_bias}, + {}, // Multiplies + }; + } else { + args.epilogue.thread = { + {ptr_scales_a}, + { + {ptr_scales_b}, + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }; + } + + return args; +} + +template +void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + auto args = prepare_sm90_fp8_args(out, a, b, scales_a, scales_b, bias); + Gemm gemm_op; + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + TORCH_CHECK(can_implement == cutlass::Status::kSuccess) + + auto status = gemm_op.run(args, workspace.data_ptr(), stream); + + TORCH_CHECK(status == cutlass::Status::kSuccess) +} + +template +void sm90_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias, bool fast_accum = true, + bool use_persistent = false) { + using ElementInput = cutlass::float_e4m3_t; + using ElementOutput = OutType; + using AccumElementType = float; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; + + if (bias) { + using Gemm = + typename DeviceGemmFp8RowwiseSm90::Gemm; + return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } else { + using Gemm = + typename DeviceGemmFp8RowwiseSm90::Gemm; + return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } +} + +template +void sm90_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + uint32_t const m = a.size(0); + using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using PersistentTileScheduler = cutlass::gemm::PersistentScheduler; + using BasicTileScheduler = void; + if (m <= 1) { + return sm90_fp8_dispatch_bias, Shape<_1, _8, _1>, FastBasicScheduler, + BasicTileScheduler>(out, a, b, scales_a, scales_b, bias); + } + if (m <= 64) { + // m in [1, 64] + return sm90_fp8_dispatch_bias, Shape<_1, _4, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } else if (m <= 256) { + // m in (64, 256] + return sm90_fp8_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } else if (m <= 1024) { + // m in (256, 1024] + return sm90_fp8_dispatch_bias, 
Shape<_1, _1, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } else { + // m in (1024, inf) + return sm90_fp8_dispatch_bias, Shape<_2, _1, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } +} +#endif + +torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias) { + TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); + TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); + TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); + TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor"); + TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); + TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); + TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); + + TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0, + "mat_a must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0, + "mat_b must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn"); + TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn"); + TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16"); + + TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched"); + TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched"); + TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous"); + TORCH_CHECK(scales_b.is_contiguous(), "scales_b msut be contiguous"); + TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32"); + TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32"); + + if (bias) { + TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched"); + TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous"); + TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype"); + } + + torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype)); + TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment"); + + auto sm_version = getSMVersion(); + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + if (sm_version >= 90) { + if (out_dtype == torch::kBFloat16) { + sm90_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm90_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } + return out; + } +#endif + +#if defined CUDA_VERSION && CUDA_VERSION >= 12040 + if (sm_version == 89) { + if (out_dtype == torch::kBFloat16) { + sm89_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm89_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } + return out; + } +#endif + + TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented fp8_scaled_mm for current compute capability: ", sm_version); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu new file mode 100644 index 00000000000..f0f3a51744e --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu @@ -0,0 +1,35 @@ 
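// Illustrative host-side usage of fp8_scaled_mm defined above (a sketch only;
// shapes, scale values and the output dtype are arbitrary, and a device served
// by one of the branches above is assumed):
//   auto f32 = torch::dtype(torch::kFloat32).device(torch::kCUDA);
//   torch::Tensor a  = torch::randn({64, 128}, f32).to(torch::kFloat8_e4m3fn);       // [M, K] row-major
//   torch::Tensor b  = torch::randn({256, 128}, f32).to(torch::kFloat8_e4m3fn).t();  // [K, N] column-major
//   torch::Tensor sa = torch::ones({64}, f32);    // one scale per row of a
//   torch::Tensor sb = torch::ones({256}, f32);   // one scale per column of b
//   torch::Tensor out = fp8_scaled_mm(a, b, sa, sb, torch::kBFloat16, c10::nullopt); // [64, 256] bf16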
+#include + +#include + +#include "utils.h" + +using namespace flashinfer; + +void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) { + CHECK_INPUT(input); + CHECK_INPUT(residual); + CHECK_INPUT(weight); + auto device = input.device(); + CHECK_EQ(residual.device(), device); + CHECK_EQ(weight.device(), device); + CHECK_DIM(2, input); // input: (batch_size, hidden_size) + CHECK_DIM(2, residual); // residual: (batch_size, hidden_size) + CHECK_DIM(1, weight); // weight: (hidden_size) + CHECK_EQ(input.size(0), residual.size(0)); + CHECK_EQ(input.size(1), residual.size(1)); + CHECK_EQ(input.size(1), weight.size(0)); + unsigned int batch_size = input.size(0); + unsigned int hidden_size = input.size(1); + + cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream(); + // support float16, bfloat16 and float32 + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + cudaError_t status = norm::FusedAddRMSNorm( + static_cast(input.data_ptr()), static_cast(residual.data_ptr()), + static_cast(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream); + TORCH_CHECK(status == cudaSuccess, + "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status))); + return true; + }); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu new file mode 100644 index 00000000000..c77851c32b6 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu @@ -0,0 +1,428 @@ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h" +#include "cutlass_extensions/gemm/gemm_universal_base_compat.h" +#include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h" +#include "utils.h" + +using namespace cute; + +template +void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementAccumulator = int32_t; + using ElementCompute = float; + using ElementInputA = int8_t; + using ElementInputB = int8_t; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>; + + using DefaultGemmConf = cutlass::gemm::device::DefaultGemmConfiguration; + using EpilogueOutputOp = typename DefaultGemmConf::EpilogueOutputOp; + + using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemm< + ElementInputA, cutlass::layout::RowMajor, DefaultGemmConf::kAlignmentA, ElementInputB, + cutlass::layout::ColumnMajor, DefaultGemmConf::kAlignmentB, ElementOutput, cutlass::layout::RowMajor, + ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp, + ThreadblockSwizzle, NumStages, true, typename DefaultGemmConf::Operator>::GemmKernel; + + using AlphaColTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + cutlass::epilogue::threadblock::OutputTileOptimalThreadMap< + typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Shape, + typename GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::Count, + GemmKernel_::Epilogue::OutputTileIterator::ThreadMap::kThreads, + GemmKernel_::Epilogue::OutputTileIterator::kElementsPerAccess, cutlass::sizeof_bits::value>, + ElementCompute>; + + using EpilogueVisitor = 
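// The visitor defined below applies the per-row and per-column dequantization
// scales to the int32 accumulator, i.e. roughly
//   D[m][n] = scales_a[m] * scales_b[n] * acc[m][n]   (+ bias[n] when a bias is given),
// which recovers the float product since A_float ~ diag(scales_a) * A_int8 and
// B_float ~ B_int8 * diag(scales_b).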
typename cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol< + ThreadblockShape, GemmKernel_::kThreadCount, AlphaColTileIterator, + typename GemmKernel_::Epilogue::OutputTileIterator, ElementAccumulator, ElementCompute, EpilogueOutputOp>; + + using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< + EpilogueVisitor, typename GemmKernel_::Epilogue>::Epilogue; + + using GemmKernel = + cutlass::gemm::kernel::GemmWithEpilogueVisitor; + + using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat; + + Gemm gemm_op; + + int m = mat_a.size(0); + int k = mat_a.size(1); + int n = mat_b.size(1); + + auto a_ptr = static_cast(mat_a.data_ptr()); + auto b_ptr = static_cast(mat_b.data_ptr()); + auto o_ptr = static_cast(out.data_ptr()); + + auto a_s_ptr = static_cast(scales_a.data_ptr()); + auto b_s_ptr = static_cast(scales_b.data_ptr()); + + int64_t lda = mat_a.stride(0); + int64_t ldb = mat_b.stride(1); + int64_t ldd = out.stride(0); + + ElementOutput* bias_ptr = nullptr; + int64_t ldc = 0; + if (bias) { + bias_ptr = static_cast(bias->data_ptr()); + } + + typename EpilogueOutputOp::Params linearScalingParams; + typename EpilogueVisitor::Arguments visitor_args{linearScalingParams}; + + typename Gemm::Arguments args{{m, n, k}, {a_ptr, lda}, {b_ptr, ldb}, {b_s_ptr, 0}, + {a_s_ptr, 0}, {bias_ptr, ldc}, {o_ptr, ldd}, visitor_args}; + + auto workspace = torch::empty(gemm_op.get_workspace_size(args), + torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device())); + + auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + TORCH_CHECK(can_implement == cutlass::Status::kSuccess, + "gemm cannot implement, error: ", cutlassGetStatusString(can_implement)); + + auto status = gemm_op(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status)); +} + +template +void sm75_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + int m = mat_a.size(0); + if (m <= 32) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, + scales_b, bias); + } else if (m <= 64) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, + scales_b, bias); + } else if (m <= 256) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, + scales_b, bias); + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 2>(out, mat_a, mat_b, scales_a, + scales_b, bias); + } +} + +template +void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + int m = mat_a.size(0); + int n = mat_b.size(1); + if (m <= 16) { + if (n <= 4096) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<16, 64, 64>, InstructionShape, 6>(out, mat_a, mat_b, scales_a, + scales_b, bias); + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<16, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, + scales_b, bias); + } + } else if (m <= 32) { + if (n <= 4096) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 6>(out, mat_a, 
mat_b, scales_a, + scales_b, bias); + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, + scales_b, bias); + } + } else if (m <= 64) { + if (n <= 4096) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, + scales_b, bias); + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, + scales_b, bias); + } + } else if (m <= 128 && n < 8192) { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, + scales_b, bias); + } else { + cutlass_int8_scaled_mm, + cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a, + scales_b, bias); + } +} + +template +void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ArchTag = cutlass::arch::Sm90; + + using ElementAccumulator = int32_t; + using ElementCompute = float; + using ElementInputA = int8_t; + using ElementInputB = int8_t; + + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; + using TileSchedulerType = cutlass::gemm::PersistentScheduler; + + using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, + Stride, Int<0>, Int<0>>>; + + using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, + Stride, Int<1>, Int<0>>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, + Stride, Int<1>, Int<0>>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + // Scale + using Compute0 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT; + + // With bias + using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute; + using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = typename cutlass::platform::conditional::type; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, ElementOutput, cutlass::layout::RowMajor, AlignmentC, ElementOutput, + cutlass::layout::RowMajor, AlignmentOutput, EpilogueScheduleType, EpilogueEVT>::CollectiveOp; + + using Stages = cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementInputA, cutlass::layout::RowMajor, AlignmentA, ElementInputB, + cutlass::layout::ColumnMajor, AlignmentB, ElementAccumulator, TileShape, ClusterShape, Stages, + MainloopScheduleType>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // 
Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + Gemm gemm_op; + + int m = mat_a.size(0); + int k = mat_a.size(1); + int n = mat_b.size(1); + + auto a_ptr = static_cast(mat_a.data_ptr()); + auto b_ptr = static_cast(mat_b.data_ptr()); + auto o_ptr = static_cast(out.data_ptr()); + + auto a_s_ptr = static_cast(scales_a.data_ptr()); + auto b_s_ptr = static_cast(scales_b.data_ptr()); + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1)); + StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1)); + StrideC stride_c; + StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1)); + + typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {a_ptr, stride_a, b_ptr, stride_b}, + {{}, // epilogue.thread + nullptr, + stride_c, + o_ptr, + stride_d}}; + + if constexpr (WithBias) { + ElementOutput* bias_ptr = static_cast(bias->data_ptr()); + args.epilogue.thread = { + {a_s_ptr}, + {{b_s_ptr}, {}, {}}, + {bias_ptr}, + {}, + }; + } else { + args.epilogue.thread = { + {a_s_ptr}, + {{b_s_ptr}, {}, {}}, + {}, + }; + } + + auto workspace = torch::empty(gemm_op.get_workspace_size(args), + torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device())); + + auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + TORCH_CHECK(can_implement == cutlass::Status::kSuccess, + "gemm cannot implement, error: ", cutlassGetStatusString(can_implement)); + + auto status = gemm_op(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status)); +} + +template +void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + if (bias) { + cutlass_int8_scaled_mm_sm90( + out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + cutlass_int8_scaled_mm_sm90( + out, mat_a, mat_b, scales_a, scales_b, bias); + } +} + +template +void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + int m = mat_a.size(0); + int n = mat_b.size(1); + if (m <= 32) { + if (n < 8192) { + return sm90_dispatch_bias, Shape<_1, _8, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + return sm90_dispatch_bias, Shape<_1, _8, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (m <= 64) { + if (n < 8192) { + return sm90_dispatch_bias, Shape<_1, _4, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + return sm90_dispatch_bias, Shape<_1, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (m <= 128) { + if (n <= 4096) { + return sm90_dispatch_bias, Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + return 
sm90_dispatch_bias, Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else { + return sm90_dispatch_bias, Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b, + bias); + } +} + +torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias) { + TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); + TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); + TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); + TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor"); + TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); + TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); + TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); + TORCH_CHECK(mat_a.size(1) % 16 == 0, "mat_a.size(1) must be multiple of 16 for memory alignment"); + TORCH_CHECK(mat_b.size(0) % 16 == 0, "mat_b.size(0) must be multiple of 16 for memory alignment"); + TORCH_CHECK(mat_b.size(1) % 8 == 0, "mat_b.size(1) must be multiple of 8 for memory alignment"); // out.stride(0) + TORCH_CHECK(mat_a.scalar_type() == torch::kInt8, "mat_a must be Int8"); + TORCH_CHECK(mat_b.scalar_type() == torch::kInt8, "mat_b must be Int8"); + TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16"); + + TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched"); + TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched"); + TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous"); + TORCH_CHECK(scales_b.is_contiguous(), "scales_b msut be contiguous"); + TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32"); + TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32"); + + if (bias) { + TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched"); + TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous"); + TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype"); + } + + torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype)); + + auto sm_version = getSMVersion(); + + if (sm_version >= 75 && sm_version < 80) { + TORCH_CHECK(out_dtype == torch::kHalf, "out_dtype must be Half for SM75"); + sm75_dispatch_shape>( + out, mat_a, mat_b, scales_a, scales_b, bias); + } else if (sm_version >= 80 && sm_version < 90) { + if (out_dtype == torch::kBFloat16) { + sm80_dispatch_shape>( + out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm80_dispatch_shape>( + out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (sm_version == 90) { +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + // cutlass 3.x + if (out_dtype == torch::kBFloat16) { + sm90_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm90_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } +#else + // fallback to cutlass 2.x + if (out_dtype == torch::kBFloat16) { + sm80_dispatch_shape>( + out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm80_dispatch_shape>( + out, mat_a, mat_b, scales_a, scales_b, bias); + } +#endif + } else { + TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented int8_scaled_mm for current compute capability."); + } 
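// Illustrative host-side usage of int8_scaled_mm (a sketch only; shapes and
// scale values are arbitrary, and a device handled by one of the branches above
// is assumed):
//   auto i8  = torch::dtype(torch::kInt8).device(torch::kCUDA);
//   auto f32 = torch::dtype(torch::kFloat32).device(torch::kCUDA);
//   torch::Tensor a  = torch::randint(-128, 128, {64, 512}, i8);       // [M, K] row-major
//   torch::Tensor b  = torch::randint(-128, 128, {256, 512}, i8).t();  // [K, N] column-major
//   torch::Tensor sa = torch::full({64}, 0.01, f32);                   // per-token scales
//   torch::Tensor sb = torch::full({256}, 0.02, f32);                  // per-channel scales
//   torch::Tensor out = int8_scaled_mm(a, b, sa, sb, torch::kHalf, c10::nullopt);  // [64, 256] fp16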
+ + return out; +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu new file mode 100644 index 00000000000..e62a154cb18 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu @@ -0,0 +1,118 @@ +#include +#include +#include +#include +#include +#include + +#define THREADS_PER_BLOCK 128 + +template +__global__ void lightning_attention_decode_kernel(const T* __restrict__ q, // [b, h, 1, d] + const T* __restrict__ k, // [b, h, 1, d] + const T* __restrict__ v, // [b, h, 1, e] + const float* __restrict__ past_kv, // [b, h, d, e] + const float* __restrict__ slope, // [h, 1, 1] + T* __restrict__ output, // [b, h, 1, e] + float* __restrict__ new_kv, // [b, h, d, e] + const int batch_size, const int num_heads, const int qk_dim, + const int v_dim) { + extern __shared__ char smem[]; + T* q_shared = reinterpret_cast(smem); + T* k_shared = reinterpret_cast(smem + qk_dim * sizeof(T)); + T* v_shared = reinterpret_cast(smem + 2 * qk_dim * sizeof(T)); + float* new_kv_shared = reinterpret_cast(smem + (2 * qk_dim + v_dim) * sizeof(T)); + T* output_shared = + reinterpret_cast(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float)); + + const int32_t tid = threadIdx.x; + const int32_t current_head = blockIdx.x; + const int32_t b = current_head / num_heads; + const int32_t h = current_head % num_heads; + + if (b >= batch_size) return; + + const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim; + const int32_t v_offset = b * num_heads * v_dim + h * v_dim; + const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim; + + for (int d = tid; d < qk_dim; d += blockDim.x) { + q_shared[d] = q[qk_offset + d]; + k_shared[d] = k[qk_offset + d]; + } + for (int e = tid; e < v_dim; e += blockDim.x) { + v_shared[e] = v[v_offset + e]; + } + + __syncthreads(); + + const float ratio = expf(-1.0f * slope[h]); + + for (int d = tid; d < qk_dim; d += blockDim.x) { + T k_val = k_shared[d]; + for (int e = 0; e < v_dim; ++e) { + int past_kv_idx = kv_offset + d * v_dim + e; + T v_val = v_shared[e]; + float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val; + int shared_idx = d * (v_dim + 1) + e; + new_kv_shared[shared_idx] = new_val; + } + } + + __syncthreads(); + + for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) { + int d = idx / v_dim; + int e = idx % v_dim; + int shared_idx = d * (v_dim + 1) + e; + int global_idx = kv_offset + idx; + new_kv[global_idx] = new_kv_shared[shared_idx]; + } + + __syncthreads(); + + for (int e = tid; e < v_dim; e += blockDim.x) { + float sum = 0.0f; + for (int d = 0; d < qk_dim; ++d) { + int shared_idx = d * (v_dim + 1) + e; + sum += q_shared[d] * new_kv_shared[shared_idx]; + } + output_shared[e] = static_cast(sum); + } + + __syncthreads(); + + if (tid == 0) { + for (int e = 0; e < v_dim; ++e) { + output[v_offset + e] = output_shared[e]; + } + } +} + +void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, + const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, + torch::Tensor new_kv) { + TORCH_CHECK(q.is_contiguous(), "q must be contiguous"); + TORCH_CHECK(k.is_contiguous(), "k must be contiguous"); + TORCH_CHECK(v.is_contiguous(), "v must be contiguous"); + TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous"); + + auto batch_size = q.size(0); + auto num_heads = q.size(1); + auto qk_dim = q.size(3); + auto v_dim = 
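// Per (batch, head) the kernel above implements the single-token decode update,
// roughly:
//   new_kv = exp(-slope[h]) * past_kv + k^T v   // outer product, [qk_dim, v_dim]
//   output = q @ new_kv                         // [1, v_dim]
// The shared-memory row stride of (v_dim + 1) pads new_kv_shared, presumably to
// reduce bank conflicts when columns are read back.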
v.size(3); + + dim3 block(THREADS_PER_BLOCK); + dim3 grid(batch_size * num_heads); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] { + size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float); + lightning_attention_decode_kernel<<>>( + q.data_ptr(), k.data_ptr(), v.data_ptr(), past_kv.data_ptr(), + slope.data_ptr(), output.data_ptr(), new_kv.data_ptr(), batch_size, num_heads, + qk_dim, v_dim); + })); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu new file mode 100644 index 00000000000..19e9850b51a --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -0,0 +1,101 @@ +// Adapted from https://github.com/vllm-project/vllm/blob/v0.6.5/csrc/moe/moe_align_sum_kernels.cu + +#include +#include +#include +#include + +#include + +#define WARP_SIZE 32 + +#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) + +#define CEILDIV(x, y) (((x) + (y)-1) / (y)) + +#define DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { + return row * total_col + col; +} + +template +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, + int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, + int32_t block_size, size_t numel, int32_t* cumsum) { + __shared__ int32_t shared_counts[32][8]; + __shared__ int32_t local_offsets[256]; + + const int warp_id = threadIdx.x / WARP_SIZE; + const int experts_per_warp = 8; + const int my_expert_start = warp_id * experts_per_warp; + + for (int i = 0; i < experts_per_warp; ++i) { + if (my_expert_start + i < num_experts) { + shared_counts[warp_id][i] = 0; + } + } + + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int expert_id = topk_ids[i]; + int warp_idx = expert_id / experts_per_warp; + int expert_offset = expert_id % experts_per_warp; + atomicAdd(&shared_counts[warp_idx][expert_offset], 1); + } + + __syncthreads(); + + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + int expert_count = 0; + int warp_idx = (i - 1) / experts_per_warp; + int expert_offset = (i - 1) % experts_per_warp; + expert_count = shared_counts[warp_idx][expert_offset]; + + cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; + } + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + local_offsets[threadIdx.x] = cumsum[threadIdx.x]; + } + + __syncthreads(); + + for (int i = start_idx; i 
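// Worked example for the kernel above: with num_experts = 2, block_size = 4 and
// topk_ids = [1, 0, 1, 1], the per-expert counts are [1, 3]; both are padded up
// to 4, so cumsum = [0, 4, 8] and *total_tokens_post_pad = 8. expert_ids becomes
// [0, 1] (one entry per block of 4 slots), and the scatter loop below writes
// token index 1 -> slot 0 and tokens 0, 2, 3 -> slots 4, 5, 6; the remaining
// padding slots keep whatever the output buffer was initialized with.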
< numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + int32_t rank_post_pad = atomicAdd(&local_offsets[expert_id], 1); + sorted_token_ids[rank_post_pad] = i; + } +} + +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, + torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, + torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) { + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + auto kernel = moe_align_block_size_kernel; + kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), + num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr()); + }); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu new file mode 100644 index 00000000000..2ee0c98c91e --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu @@ -0,0 +1,515 @@ +// reference: +// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.cu +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include +#include +#include +#include +#include + +#include "trt_reduce_internal.cuh" +#include "utils.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr) { + asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr) { + uint32_t flag; + asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); + return flag; +} + +static inline __device__ void st_flag_volatile(uint32_t const& flag, uint32_t* flag_addr) { + asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +} + +static inline __device__ uint32_t ld_flag_volatile(uint32_t* flag_addr) { + uint32_t flag; + asm volatile("ld.volatile.global.u32 %0, [%1];" : "=r"(flag) : "l"(flag_addr)); + return flag; +} + +namespace trt_llm { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Type Converter that packs data format to 128 bits data type +// +using PackedFloat = union { + int4 packed; + float unpacked[4]; +}; + +using PackedHalf = union { + int4 packed; + half2 unpacked[4]; +}; + +template +struct PackedOn16Bytes {}; + +template <> +struct PackedOn16Bytes { + using Type = PackedFloat; +}; + +template <> +struct PackedOn16Bytes { + using Type = PackedHalf; +}; + +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +using PackedBFloat16 = union { + int4 packed; + __nv_bfloat162 unpacked[4]; +}; + +template <> +struct PackedOn16Bytes<__nv_bfloat16> { + using Type = PackedBFloat16; +}; +#endif + +// add two 128b data +template +inline __device__ int4 add128b(T& a, T& b) { + T c; + c.unpacked[0] = a.unpacked[0] + b.unpacked[0]; + c.unpacked[1] = a.unpacked[1] + b.unpacked[1]; + c.unpacked[2] = a.unpacked[2] + b.unpacked[2]; + c.unpacked[3] = a.unpacked[3] + b.unpacked[3]; + return c.packed; +} + +__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank, + size_t const world_size, int const tidx, int const bidx) { + // After this function, at least one block in each GPU has reached the barrier + if (tidx < world_size) { + // we can think of signals having the shape [world_size, world_size] + // Dimension 0 is the "listening" dimension, dimension 1 is "emitting" dimension + + // Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers + size_t offset = (flag % 2) ? 
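// The (flag % 2) parity below selects one of two disjoint sets of signal slots,
// so back-to-back barriers never reuse the same locations: a rank that has
// already entered barrier N+1 cannot overwrite a flag that a slower rank is
// still polling for barrier N.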
world_size : 0; + + if (bidx == 0) { + st_flag_release(flag, signals[tidx] + offset + local_rank); + } + + // All blocks check that corresponding block 0 on other GPUs have set the flag + // No deadlock because block #0 is always the first block started + uint32_t* peer_barrier_d = signals[local_rank] + offset + tidx; + while (ld_flag_acquire(peer_barrier_d) != flag) { + } + } + + __syncthreads(); +} + +template +__inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank, + size_t const world_size, int const tidx, int const bidx, int const grid_size) { + if constexpr (!start) { + __syncthreads(); + } + // After this function, the block of id == bidx of each GPU has reached the barrier + if (tidx < world_size) { + // we can think of signals having the shape [world_size, 2, num_blocks, world_size] + // (+ an offset on dim 2 to account for flags used in multi_gpu_barrier) + // Dimension 0 is the "listening" dimension, dimension 3 is "emitting" dimension + + // Block broadcast its flag (local_rank on emitting dimension) to all receivers + uint32_t flag_block_offset = world_size + bidx * world_size; + + flag_block_offset += (grid_size + 1) * world_size * (flag % 2); + + uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx; + // Blocks check that corresponding blocks on other GPUs have also set the flag + if constexpr (need_fence) { + st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank); + while (ld_flag_acquire(peer_barrier_d) != flag) { + } + } else { + st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank); + while (ld_flag_volatile(peer_barrier_d) != flag) { + } + } + } + + __syncthreads(); +} + +template +static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduceParams params) { + // Suppose that two GPUs participate in the AR exchange, and we start four blocks. + // The message is partitioned into chunks as detailed below: + // message + // |-------------------| + // GPU 0 | B0 | B1 | B2 | B3 | + // GPU 1 | B0 | B1 | B2 | B3 | + // + // Here the step-by-step behavior of one block: + // 1. B0 copies the chunk it is responsible for, from local_input to shareable buffer + // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier) + // 3. B0 on GPU 0 pull and sum the chunk from GPU 1, writes the result to local_output + // + // With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2. + // We only to know if the other GPU as arrived at the AR kernel, that would mean that data is ready + // + // With PUSH_MODE, we consider that the shared buffer is of size: + // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size] + // + // Here the step-by-step behavior of one block: + // 1. B0 push the chunk is it responsible for into all other GPUs: + // params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice] + // 2. block sync so the block is shared by other GPUs + // 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice] + + int const bidx = blockIdx.x; + int const tidx = threadIdx.x; + int const grid_size = gridDim.x; + + // The number of elements packed into one for comms + static constexpr int NUM_ELTS = 16 / sizeof(T); + + // Packed data type for comms + using PackedStruct = typename PackedOn16Bytes::Type; + + // The source pointers. Distributed round-robin for the different warps. 
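// Worked example of the packing: for T = half, NUM_ELTS = 16 / sizeof(half) = 8,
// so each thread moves one 128-bit int4 (8 half values) per access and a
// 512-thread block covers 4096 elements per iteration of the copy/reduce loops
// below.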
+ auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs; + T* local_shared_buffer = reinterpret_cast(peer_comm_buffer_ptrs[params.local_rank]); + // Start and end offsets of the thread + size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS; + size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank); + + if constexpr (COPY_INPUT) { + T const* local_input_buffer = reinterpret_cast(params.local_input_buffer_ptr); + // Copy from local buffer to shareable buffer + for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) { + *reinterpret_cast(&local_shared_buffer[iter_offset]) = + *reinterpret_cast(&local_input_buffer[iter_offset]); + } + } + // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer + block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, + grid_size); + + // Each block accumulates the values from the different GPUs on the same node. + for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) { + // Iterate over the different ranks/devices on the node to load the values. + PackedStruct vals[RANKS_PER_NODE]; +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + vals[ii].packed = *reinterpret_cast(&((T*)peer_comm_buffer_ptrs[ii])[iter_offset]); + } + + // Sum the values from the different ranks. + PackedStruct sums; + sums.packed = {0, 0, 0, 0}; +#pragma unroll + for (int rank = 0; rank < RANKS_PER_NODE; ++rank) { + // Always reduce from rank 0 to ensure stable reduce order. + sums.packed = add128b(sums, vals[rank]); + } + + // Store to the destination buffer. + *reinterpret_cast(&reinterpret_cast(params.local_output_buffer_ptr)[iter_offset]) = sums.packed; + } +} + +template +static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) { + // Suppose that two GPUs participate in the AR exchange, and we start two blocks. + // The message is partitioned into chunks as detailed below: + // message + // |-------------------| + // |--GPU 0--|--GPU 1--| (GPU responsibility parts) + // GPU 0 | B0 | B1 | B0 | B1 | + // GPU 1 | B0 | B1 | B0 | B1 | + // + // Here the step-by-step behavior of one block: + // 1. B0 copies all chunks is it responsible for, from local_input to shareable buffer + // 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #0) + // 3. B0 on GPU 0 gather and sum the B0 chunks from GPU 1, that are in the GPU 0 responsibility + // part (the first half of the message, see GPU responsibility row above) + // 3bis. Likewise, B0 on GPU 1 copies and sum the chunks for GPU 0, + // where GPU 1 is responsible: the second half of the message. + // 4. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #1) + // 5. B0 writes result to local_output. It gathers each chunk from its responsible GPU. + // For example, here it reads the first chunk from GPU 0 and second chunk from GPU 1. + // + // With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2. + // We only to know if the other GPU as arrived at the AR kernel, that would mean that data is ready + // to be read. + // + // Note that compared to one-shot, one block (CTA) writes multiple input chunks and write multiple output chunks. + // However, it's only responsible for the summation of a single chunk. 
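// Worked example of this partitioning: with two GPUs and a 1M-element message,
// elts_per_rank = 512K and rank_offset = local_rank * 512K, so GPU 0 reduces the
// first half, GPU 1 the second half, and the final gather phase copies each
// reduced half back from the rank that owns it.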
+ // + // With PUSH_MODE, we consider that the shared buffer is of size: + // params.peer_comm_buffer_ptrs: [world_size, world_size, message_size / world_size] + // + // Here the step-by-step behavior of one block: + // 1. B0 push the chunks is it responsible for into the corresponding GPUs: + // params.peer_comm_buffer_ptrs[target_gpu, local_gpu, current B0 slice] + // 2. block sync so the blocks have been shared by other GPUs + // 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice] + // 4. block barrier (corresponding blocks have finished reduction) + // 5. pull and write on local buffer, by reading params.peer_comm_buffer_ptrs[:, 0, B0 slice] (reduction result is + // written at index 0 of 2nd dim) + + int const bidx = blockIdx.x; + int const tidx = threadIdx.x; + int const grid_size = gridDim.x; + + // The number of elements packed into one for comms + static constexpr int PACKED_ELTS = 16 / sizeof(T); + using PackedType = typename PackedOn16Bytes::Type; + + T const* local_input_buffer = reinterpret_cast(params.local_input_buffer_ptr); + auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs; + T* local_shared_buffer = reinterpret_cast(peer_comm_buffer_ptrs[params.local_rank]); + T* local_output_buffer = reinterpret_cast(params.local_output_buffer_ptr); + + size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS; + size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank); + + T* buffers[RANKS_PER_NODE]; + T* buffers_unorder[RANKS_PER_NODE]; + int ranks[RANKS_PER_NODE]; +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + // A mapping of the ranks to scatter reads as much as possible + int rank = (params.local_rank + ii) % RANKS_PER_NODE; + ranks[ii] = rank; + buffers[ii] = reinterpret_cast(peer_comm_buffer_ptrs[rank]); + buffers_unorder[ii] = reinterpret_cast(peer_comm_buffer_ptrs[ii]); + } + +#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12)) +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif +#endif + + if constexpr (COPY_INPUT) { + // Copy all blocks from local buffer to shareable buffer + for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset; + if (offset_rank >= params.elts_total) { + continue; + } + *reinterpret_cast(&local_shared_buffer[offset_rank]) = + *reinterpret_cast(&local_input_buffer[offset_rank]); + } + } + } + block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, + grid_size); + + // Each block accumulates the values from the different GPUs on the same node. + for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { + size_t const responsible_block_offset = local_offset + params.rank_offset; + + // Iterate over the different ranks/devices on the node to load the values. + PackedType vals[RANKS_PER_NODE]; +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + vals[ii].packed = *reinterpret_cast(&buffers_unorder[ii][responsible_block_offset]); + } + + // Sum the values from the different ranks. + PackedType sums; + sums.packed = {0, 0, 0, 0}; +#pragma unroll + for (int rank = 0; rank < RANKS_PER_NODE; ++rank) { + // Always reduce from rank 0 to ensure stable reduce order. 
+ sums.packed = add128b(sums, vals[rank]); + } + + // Store to the local buffer or tmp buffer + if constexpr (COPY_INPUT) { + *reinterpret_cast(&local_shared_buffer[responsible_block_offset]) = sums.packed; + } else { + *reinterpret_cast(¶ms.tmp_result_buffers[params.local_rank][responsible_block_offset]) = sums.packed; + } + } + + block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, + bidx, grid_size); + + // Gather all needed elts from other intra-node ranks + for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + // use round-robin gathering from other ranks + size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset; + if (offset_rank >= params.elts_total) { + continue; + } + if constexpr (COPY_INPUT) { + *reinterpret_cast(&local_output_buffer[offset_rank]) = + *reinterpret_cast(&buffers[ii][offset_rank]); + } else { + *reinterpret_cast(&local_output_buffer[offset_rank]) = + *reinterpret_cast(¶ms.tmp_result_buffers[ranks[ii]][offset_rank]); + } + } + } +#if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12)) +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline int divUp(int a, int b) { + return (a + b - 1) / b; +} + +inline int roundUp(int a, int n) { + return divUp(a, n) * n; +} + +std::tuple kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& params, size_t elts_per_thread) { + int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE; + switch (algo) { + case AllReduceStrategyType::ONESHOT: { + assert(params.elts_total % elts_per_thread == 0); + size_t const total_threads = roundUp(params.elts_total / elts_per_thread, WARP_SIZE); + threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads); + blocks_per_grid = std::min(static_cast(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block)); + params.elts_per_block = roundUp(divUp(params.elts_total, blocks_per_grid), elts_per_thread); + params.elts_per_rank = params.elts_total; + break; + } + case AllReduceStrategyType::TWOSHOT: { + assert(params.elts_total % (elts_per_thread * params.ranks_per_node) == 0); + size_t const total_threads = roundUp(params.elts_total / (elts_per_thread * params.ranks_per_node), WARP_SIZE); + + /* + threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads); + blocks_per_grid = std::min(static_cast(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block)); + */ + while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE) { + blocks_per_grid += 1; + } + + threads_per_block = total_threads / blocks_per_grid; + + // NOTE: need to adjust here + if (blocks_per_grid > MAX_ALL_REDUCE_BLOCKS) { + size_t iter_factor = 1; + while (blocks_per_grid / iter_factor > MAX_ALL_REDUCE_BLOCKS || blocks_per_grid % iter_factor) { + iter_factor += 1; + } + blocks_per_grid /= iter_factor; + } + params.elts_per_rank = params.elts_total / params.ranks_per_node; + params.rank_offset = params.local_rank * params.elts_per_rank; + params.elts_per_block = roundUp(divUp(params.elts_per_rank, blocks_per_grid), elts_per_thread); + break; + } + default: + assert(false && "Algorithm not supported here."); + } + + return std::make_tuple(blocks_per_grid, threads_per_block); +} + 
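+// Worked example for the ONESHOT branch above (illustrative only): with fp16
+// and elts_total = 1048576, elts_per_thread = 16 / 2 = 8, so
+// total_threads = roundUp(1048576 / 8, WARP_SIZE) = 131072. threads_per_block
+// is capped at DEFAULT_BLOCK_SIZE (512) and blocks_per_grid at
+// MAX_ALL_REDUCE_BLOCKS (36), which gives
+// elts_per_block = roundUp(divUp(1048576, 36), 8) = 29128, i.e. each block
+// iterates its copy/reduce loops roughly 29128 / (512 * 8) ~= 7 times.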
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block, + cudaStream_t stream) { + switch (algo) { + case AllReduceStrategyType::ONESHOT: { + oneShotAllReduceKernel<<>>(param); + break; + } + case AllReduceStrategyType::TWOSHOT: { + twoShotAllReduceKernel<<>>(param); + break; + } + } +} + +template +void dispatchARKernelsCopyInput(AllReduceStrategyType strat, AllReduceParams& param, cudaStream_t stream) { + size_t elts_per_thread = 16 / sizeof(T); + auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread); + switch (param.ranks_per_node) { + case 2: + dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); + break; + case 4: + dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); + break; + case 6: + dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); + break; + case 8: + dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); + break; + default: + break; + } +} + +template +void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) { + if (param.is_capturing) { + dispatchARKernelsCopyInput(strat, param, stream); + } else { + dispatchARKernelsCopyInput(strat, param, stream); + } + CHECK_CUDA_SUCCESS(cudaGetLastError()); +} + +void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, + cudaStream_t stream) { + if (params.elts_total == 0) { + return; + } + + switch (data_type) { + case at::ScalarType::Float: + invokeOneOrTwoShotAllReduceKernel(params, strat, stream); + break; + case at::ScalarType::Half: + invokeOneOrTwoShotAllReduceKernel(params, strat, stream); + break; +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) + case at::ScalarType::BFloat16: + invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream); + break; +#endif + default: + assert(false && "Unsupported data type"); + } +} +} // namespace trt_llm diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu new file mode 100644 index 00000000000..fd0483e39ee --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu @@ -0,0 +1,201 @@ +// reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h + +#include + +#include + +#include "trt_reduce_internal.cuh" +#include "utils.h" + +using namespace trt_llm; + +using fptr_t = int64_t; +using IPC_KEY = std::array; + +class AllReduceMeta { + public: + AllReduceMeta(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, + const std::vector& tmp_result_buffers, const std::vector& barrier_in, + const std::vector& barrier_out) { + this->rank_id = (int)rank_id; + this->world_size = (int)world_size; + this->barrier_in = barrier_in; + this->barrier_out = barrier_out; + this->tmp_result_buffers = tmp_result_buffers; + + this->rank_data_base = reinterpret_cast(rank_data.data_ptr()); + RankData data; + for (int i = 0; i < world_size; i++) { + data.ptrs[i] = (void*)buffers[i]; + } + auto d_data = this->rank_data_base++; + CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); + this->buffers = d_data; + } + + ~AllReduceMeta() { + for (auto [_, ptr] : ipc_handles_) { + 
CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr)); + } + } + + public: + int world_size; + int rank_id; + std::vector barrier_in; + std::vector barrier_out; + std::vector tmp_result_buffers; + int barrier_flag = 1; + RankData* buffers; + RankData* rank_data_base; + std::vector graph_unreg_buffers; + std::map ipc_handles_; +}; + +// Get the number of bits for a given data type. +inline int get_bits(at::ScalarType dtype) { + switch (dtype) { + case at::ScalarType::Float: + return 32; + case at::ScalarType::Half: + case at::ScalarType::BFloat16: + return 16; + default: + assert(false && "Unsupported data type"); + } +} + +// Check if customized all-reduce kernels can be applied. +inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) { + // The customized all-reduce kernel has the following requirement(s). + return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0; +} + +fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, + const std::vector& tmp_result_buffers, const std::vector& barrier_in, + const std::vector& barrier_out) { + auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out); + return (fptr_t)m; +} + +void dispose(fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + delete fa; +} + +std::tuple, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa) { + AllReduceMeta* m = reinterpret_cast(_fa); + auto num_buffers = m->graph_unreg_buffers.size(); + auto handle_sz = sizeof(cudaIpcMemHandle_t); + std::string handles(handle_sz * num_buffers, static_cast(0)); + std::vector offsets(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto ptr = m->graph_unreg_buffers[i]; + void* base_ptr; + // note: must share the base address of each allocation, or we get wrong + // address + if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS) { + assert(false && "failed to get pointer attr"); + } + + CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char*)ptr) - ((char*)base_ptr); + } + std::vector bytes(handles.begin(), handles.end()); + return std::make_pair(bytes, offsets); +} + +char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) { + auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); + if (new_handle) { + char* ipc_ptr; + CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle((void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), + cudaIpcMemLazyEnablePeerAccess)); + it->second = ipc_ptr; + } + return it->second; +} + +// Note: when registering graph buffers, we intentionally choose to not +// deduplicate the addresses. That means if the allocator reuses some +// addresses, they will be registered again. This is to account for the remote +// possibility of different allocation patterns between ranks. For example, +// rank 1 may get the same input address for the second allreduce, but rank 2 +// got a different address. IPC handles have internal reference counting +// mechanism so overhead should be small. 
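+// Expected call sequence (a sketch of how the host side is assumed to drive
+// these functions, not a definition of the Python API): after CUDA graph
+// capture, each rank calls get_graph_buffer_ipc_meta() to obtain the IPC
+// handles and offsets of the buffers recorded in graph_unreg_buffers, the
+// ranks exchange those lists out of band, and every rank then calls
+// register_graph_buffers() with the gathered handles/offsets so the RankData
+// entries can be filled with its peers' device addresses.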
+void register_graph_buffers(fptr_t _fa, const std::vector>& handles, + const std::vector>& offsets) { + AllReduceMeta* m = reinterpret_cast(_fa); + std::vector handle_bytes; + handle_bytes.reserve(handles.size()); + for (int i = 0; i < handles.size(); i++) { + handle_bytes.emplace_back(handles[i].begin(), handles[i].end()); + } + auto num_buffers = m->graph_unreg_buffers.size(); + std::vector rank_data(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto self_ptr = m->graph_unreg_buffers[i]; + auto& rd = rank_data[i]; + for (int j = 0; j < m->world_size; j++) { + if (j != m->rank_id) { + char* handle = open_ipc_handle(m, &handle_bytes[j][i * sizeof(cudaIpcMemHandle_t)]); + handle += offsets[j][i]; + rd.ptrs[j] = handle; + } else { + rd.ptrs[j] = self_ptr; + } + } + } + CHECK_CUDA_SUCCESS( + cudaMemcpy(m->rank_data_base, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice)); + m->rank_data_base += num_buffers; + m->graph_unreg_buffers.clear(); +} + +void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { + AllReduceMeta* m = reinterpret_cast(_fa); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + auto num_elements = inp.numel(); + auto dtype = inp.scalar_type(); + AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), m->world_size); + + // should be gurantee in python code + assert(strategy == AllReduceStrategyType::ONESHOT || strategy == AllReduceStrategyType::TWOSHOT); + assert(CanApplyCustomAllReduce(num_elements, dtype)); + + // Initialize the all-reduce kernel arguments. + int world_size = m->world_size; + + AllReduceParams params; + params.ranks_per_node = world_size; + params.rank = m->rank_id; + params.local_rank = m->rank_id; + params.local_input_buffer_ptr = inp.data_ptr(); + params.local_output_buffer_ptr = out.data_ptr(); + params.elts_total = inp.numel(); + params.elts_size = inp.element_size(); + params.barrier_flag = ++(m->barrier_flag); + + cudaStreamCaptureStatus status; + CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status)); + params.is_capturing = (status == cudaStreamCaptureStatusActive); + if (params.is_capturing) { + params.peer_comm_buffer_ptrs = m->rank_data_base + m->graph_unreg_buffers.size(); + m->graph_unreg_buffers.push_back(params.local_input_buffer_ptr); + } else { + params.peer_comm_buffer_ptrs = m->buffers; + } + + for (int i = 0; i < world_size; ++i) { + params.tmp_result_buffers[i] = reinterpret_cast(m->tmp_result_buffers[i]); + } + for (int i = 0; i < world_size; ++i) { + params.peer_barrier_ptrs_in[i] = reinterpret_cast(m->barrier_in[i]); + } + for (int i = 0; i < world_size; ++i) { + params.peer_barrier_ptrs_out[i] = reinterpret_cast(m->barrier_out[i]); + } + + auto data_type = out.scalar_type(); + trtCustomAllReduce(params, data_type, strategy, stream); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc deleted file mode 100644 index 46c6a41c3ac..00000000000 --- a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce.cc +++ /dev/null @@ -1,20 +0,0 @@ -#include - -torch::Tensor warp_reduce_cuda(torch::Tensor input); - -#define CHECK_CUDA(x) \ - TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) - -torch::Tensor warp_reduce(torch::Tensor input) { - CHECK_INPUT(input); - return warp_reduce_cuda(input); -} - 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("reduce", &warp_reduce, "Warp Reduce (CUDA)"); -} diff --git a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu deleted file mode 100644 index c547682f60d..00000000000 --- a/sgl-kernel/src/sgl-kernel/csrc/warp_reduce_kernel.cu +++ /dev/null @@ -1,97 +0,0 @@ -#include -#include -#include - -#define FINAL_MASK 0xffffffff -#define BLOCK_SIZE 256 - -template -__device__ __forceinline__ scalar_t add(scalar_t a, scalar_t b) { - return a + b; -} - -template -__device__ __forceinline__ scalar_t warpReduceSum(scalar_t val) { -#pragma unroll - for (int offset = 16; offset > 0; offset /= 2) { - val += __shfl_down_sync(FINAL_MASK, val, offset); - } - return val; -} - -template -__device__ __forceinline__ scalar_t blockReduceSum(scalar_t val) { - __shared__ scalar_t shared[32]; - int lane = threadIdx.x % 32; - int wid = threadIdx.x / 32; - - val = warpReduceSum(val); // First reduce within warp - - if (lane == 0) - shared[wid] = val; // Write reduced value to shared memory - - __syncthreads(); // Wait for all partial reductions - - // Read from shared memory only if that warp existed - val = (threadIdx.x < (blockDim.x / 32)) ? shared[lane] : 0; - - if (wid == 0) - val = warpReduceSum(val); // Final reduce within first warp - - return val; -} - -template -__global__ void warp_reduce_cuda_kernel( - const torch::PackedTensorAccessor32 - input, - torch::PackedTensorAccessor32 output, - int N) { - - scalar_t sum = 0; - - // Grid-stride loop - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - sum += input[i]; - } - - // Perform block-wide reduction - sum = blockReduceSum(sum); - - // Write result for this block to global memory - if (threadIdx.x == 0) { - output[blockIdx.x] = sum; - } -} - -torch::Tensor warp_reduce_cuda(torch::Tensor input) { - // Input validation - TORCH_CHECK(input.dim() == 1, "1D tensor expected"); - TORCH_CHECK(input.is_cuda(), "CUDA tensor expected"); - - const auto N = input.size(0); - - // Handle empty tensor - if (N == 0) { - return torch::zeros({1}, input.options()); - } - - // Calculate grid dimensions - const int threads = BLOCK_SIZE; - const int blocks = (N + threads - 1) / threads; - - // Allocate output tensor for partial sums - auto output = torch::empty({blocks}, input.options()); - - AT_DISPATCH_FLOATING_TYPES( - input.scalar_type(), "warp_reduce_cuda", ([&] { - warp_reduce_cuda_kernel<<>>( - input.packed_accessor32(), - output.packed_accessor32(), - N); - })); - - // Sum the partial results - return output.sum(); -} diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h new file mode 100644 index 00000000000..c5cc30c1888 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -0,0 +1,114 @@ +#pragma once + +#include +#include + +#include + +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } + +// trt_reduce +using fptr_t = int64_t; +fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, + const std::vector& 
tmp_result_buffers, const std::vector& barrier_in, + const std::vector& barrier_out); +void dispose(fptr_t _fa); +void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); +std::tuple, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector>& handles, + const std::vector>& offsets); + +// moe_align_block_size +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, + torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, + torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer); + +// int8_scaled_mm +torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias); + +// fp8_scaled_mm +torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias); + +// lightning_attention_decode +void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, + const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, + torch::Tensor new_kv); + +// rms norm +void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); + +// fused rms norm +void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps); + +// gemma rms norm +void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); + +// fused gemma rms norm +void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, + int64_t cuda_stream); + +// silu and mul +void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +// gelu tanh and mul +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +// gelu and mul +void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +// bmm fp8 +void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, + at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream); + +// min p sampling from probs +void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + std::optional maybe_min_p_arr, double min_p_val, bool deterministic, + int64_t cuda_stream); + +// top k renorm probs +// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. +void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_k_arr, + unsigned int top_k_val, int64_t cuda_stream); + +// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. 
+// wrapper for binding +inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs, + std::optional maybe_top_k_arr, int64_t top_k_val, + int64_t cuda_stream) { + top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast(top_k_val), cuda_stream); +} + +// top p renorm probs +void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_p_arr, + double top_p_val, int64_t cuda_stream); + +// top k top p sampling from probs +void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + at::Tensor success, std::optional maybe_top_k_arr, double top_k_val, + std::optional maybe_top_p_arr, double top_p_val, bool deterministic, + int64_t cuda_stream); + +// top p sampling from probs +void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, + std::optional maybe_top_p_arr, double top_p_val, bool deterministic, + int64_t cuda_stream); + +void apply_rope_pos_ids_cos_sin_cache(at::Tensor q, at::Tensor k, at::Tensor q_rope, at::Tensor k_rope, + at::Tensor cos_sin_cache, at::Tensor pos_ids, bool interleave, + int64_t cuda_stream); diff --git a/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh new file mode 100644 index 00000000000..46522348aaf --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh @@ -0,0 +1,94 @@ +// reference: +// https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/plugins/ncclPlugin/allreducePlugin.cpp +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +namespace trt_llm { +constexpr size_t WARP_SIZE = 32; +constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36; +constexpr size_t MAX_RANKS_PER_NODE = 8; +constexpr size_t DEFAULT_BLOCK_SIZE = 512; + +enum class AllReduceStrategyType : int8_t { + RING = 0, + ONESHOT = 1, + TWOSHOT = 2, + AUTO = 3, +}; + +struct RankData { + void* ptrs[MAX_RANKS_PER_NODE]; +}; + +struct AllReduceParams { + size_t elts_size; + size_t elts_total; + size_t elts_per_rank; + size_t elts_per_block; + size_t rank_offset; + size_t ranks_per_node, rank, local_rank; + uint32_t barrier_flag; + uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE]; + uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE]; + uint32_t* tmp_result_buffers[MAX_RANKS_PER_NODE]; + RankData* peer_comm_buffer_ptrs; + void* local_input_buffer_ptr; + void* local_output_buffer_ptr; + bool is_capturing; +}; + +inline size_t GetMaxRequiredWorkspaceSize(int world_size) { + if (world_size <= 2) { + return 16 * 1024 * 1024; + } + return 8 * 1024 * 1024; +} + +inline AllReduceStrategyType SelectImplementation(size_t message_size, int world_size) { + const size_t maxWorkspaceSize = GetMaxRequiredWorkspaceSize(world_size); + + if (message_size > maxWorkspaceSize) { + assert(false && "Custom allreduce do not ring currently"); + return AllReduceStrategyType::RING; + } + + if (world_size <= 2) { + return AllReduceStrategyType::ONESHOT; + } + + if (world_size <= 4) { + if (message_size < 1 * 1024 * 1024) { + return AllReduceStrategyType::ONESHOT; + } + return AllReduceStrategyType::TWOSHOT; + } + + if (message_size < 512 * 1024) { + return AllReduceStrategyType::ONESHOT; + } + return AllReduceStrategyType::TWOSHOT; +} + +void trtCustomAllReduce(AllReduceParams& params, at::ScalarType data_type, AllReduceStrategyType strat, + cudaStream_t stream); + +} // namespace trt_llm diff --git a/sgl-kernel/src/sgl-kernel/include/utils.h b/sgl-kernel/src/sgl-kernel/include/utils.h new file mode 100644 index 00000000000..55594f7b273 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -0,0 +1,66 @@ +#pragma once + +#include +#include +#include + +#include + +struct cuda_error : public std::runtime_error { + /** + * @brief Constructs a `cuda_error` object with the given `message`. + * + * @param message The error char array used to construct `cuda_error` + */ + cuda_error(const char* message) : std::runtime_error(message) {} + /** + * @brief Constructs a `cuda_error` object with the given `message` string. 
+ * + * @param message The `std::string` used to construct `cuda_error` + */ + cuda_error(std::string const& message) : cuda_error{message.c_str()} {} +}; + +#define CHECK_CUDA_SUCCESS(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + std::stringstream _message; \ + auto s = cudaGetErrorString(e); \ + _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ + throw cuda_error(_message.str()); \ + } \ + } while (0) + +#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CUDA_INPUT(x) \ + CHECK_IS_CUDA(x); \ + CHECK_IS_CONTIGUOUS(x) + +inline int getSMVersion() { + int device{-1}; + CHECK_CUDA_SUCCESS(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); + CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Float: { \ + using c_type = float; \ + return __VA_ARGS__(); \ + } \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index 21870032e5a..5aa484ff54d 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -1,5 +1,497 @@ -from .warp_reduce_cuda import reduce as _reduce +import os +from typing import Optional, Tuple, Union +import sgl_kernel.ops._kernels +import torch +from sgl_kernel.ops.utils import ( + _get_cache_buf, + _get_cuda_stream, + _to_tensor_scalar_tuple, +) -def warp_reduce(input_tensor): - return _reduce(input_tensor) + +def apply_rope_with_cos_sin_cache_inplace( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + is_neox: bool = True, +) -> None: + r""" + Apply rotary embedding to keys and queries with precomputed cos/sin values. + This is designed to be compatible with the SGL/vLLM implementation. + The result is inplace applied to the input tensors. + + Parameters + ---------- + positions : torch.Tensor + Position indices, shape: ``(nnz)``. + query : torch.Tensor + Query tensor, shape: ``(nnz, num_q_heads * head_size)``. + key : torch.Tensor + Key tensor, shape: ``(nnz, num_k_heads * head_size)``. + cos_sin_cache : torch.Tensor + Cosine and Sine cache tensor, shape: ``(max_seq_len, rotary_dim)``. + Cosine is the first half and Sine is the second half on rotary_dim. + is_neox : bool + Whether to use Neox style RoPE, default: ``True``. + + * If ``True``, the last dimension of the query/key tensor is not interleaved, i.e., + we rorate the first half dimensions ``([..., :head_dim//2])`` and the second half + dimensions ``([..., head_dim//2:])``. + + * If ``False``, the last dimension of the query/key tensor is interleaved, i.e., + we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``. + Note + ---- + The rotary dimension is determined by the cosine cache and sine cache. 
+ """ + if cos_sin_cache.dtype != torch.float32: + raise ValueError("cos_sin_cache should be float32") + + with query.device as device: + positions = positions.int() + torch.ops.sgl_kernels.apply_rope_pos_ids_cos_sin_cache( + q=query.view(query.shape[0], -1, head_size), + k=key.view(key.shape[0], -1, head_size), + q_rope=query.view(query.shape[0], -1, head_size), + k_rope=key.view(key.shape[0], -1, head_size), + cos_sin_cache=cos_sin_cache, + pos_ids=positions, + interleave=(not is_neox), + cuda_stream=_get_cuda_stream(device), + ) + + +def init_custom_reduce( + rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out +): + return torch.ops.sgl_kernels.init_custom_ar( + rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out + ) + + +def custom_dispose(fa): + torch.ops.sgl_kernels.dispose(fa) + + +def custom_reduce(fa, inp, out): + torch.ops.sgl_kernels.all_reduce(fa, inp, out) + + +def get_graph_buffer_ipc_meta(fa): + return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa) + + +def register_graph_buffers(fa, handles, offsets): + torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets) + + +def moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, + token_cnts_buffer, + cumsum_buffer, +): + torch.ops.sgl_kernels.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_token_ids, + experts_ids, + num_tokens_post_pad, + token_cnts_buffer, + cumsum_buffer, + ) + + +def sampling_scaling_penalties(logits, scaling_penalties): + return torch.ops.sgl_kernels.sampling_scaling_penalties(logits, scaling_penalties) + + +def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): + return torch.ops.sgl_kernels.int8_scaled_mm( + mat_a, + mat_b, + scales_a, + scales_b, + out_dtype, + bias, + ) + + +def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): + return torch.ops.sgl_kernels.fp8_scaled_mm( + mat_a, + mat_b, + scales_a, + scales_b, + out_dtype, + bias, + ) + + +def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv): + torch.ops.sgl_kernels.lightning_attention_decode( + q, k, v, past_kv, slope, output, new_kv + ) + + +# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer +# Kudos to @yzh119 +def rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + with input.device as device: + if out is None: + out = torch.empty_like(input) + torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, _get_cuda_stream(device)) + return out + + +def fused_add_rmsnorm( + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +) -> None: + with input.device as device: + torch.ops.sgl_kernels.fused_add_rmsnorm(input, residual, weight, eps) + + +def gemma_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + with input.device as device: + if out is None: + out = torch.empty_like(input) + torch.ops.sgl_kernels.gemma_rmsnorm( + out, input, weight, eps, _get_cuda_stream(device) + ) + return out + + +def gemma_fused_add_rmsnorm( + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +) -> None: + with input.device as device: + torch.ops.sgl_kernels.gemma_fused_add_rmsnorm( + input, residual, weight, eps, _get_cuda_stream(device) + ) + + +def 
_check_shape(input: torch.Tensor, output: torch.Tensor) -> None: + assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}" + assert ( + input.shape[:-1] == output.shape[:-1] + ), f"{input.shape[:-1]} != {output.shape[:-1]}" + assert ( + input.shape[-1] == 2 * output.shape[-1] + ), f"{input.shape[-1]} != {2 * output.shape[-1]}" + + +def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + torch.ops.sgl_kernels.silu_and_mul(out, input, _get_cuda_stream(device)) + return out + + +def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, _get_cuda_stream(device)) + return out + + +def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + torch.ops.sgl_kernels.gelu_and_mul(out, input, _get_cuda_stream(device)) + return out + + +def _bmm_fp8_internal( + workspace_buffer: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + D: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, +) -> None: + with A.device as device: + cublas_handle = torch.cuda.current_blas_handle() + torch.ops.sgl_kernels.bmm_fp8( + A, + B, + D, + A_scale, + B_scale, + workspace_buffer, + cublas_handle, + _get_cuda_stream(device), + ) + + +def bmm_fp8( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if out is None: + out = torch.empty( + (A.shape[0], A.shape[1], B.shape[2]), + device=A.device, + dtype=dtype, + ) + workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device) + _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale) + return out + + +def _top_k_renorm_probs_internal( + probs: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, +) -> torch.Tensor: + with probs.device as device: + probs = probs.float() + maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + renorm_probs = torch.empty_like(probs) + torch.ops.sgl_kernels.top_k_renorm_probs_wrapper( + probs, + renorm_probs, + maybe_top_k_arr, + top_k_val, + _get_cuda_stream(device), + ) + return renorm_probs + + +def top_k_renorm_probs( + probs: torch.Tensor, + top_k: Union[torch.Tensor, int], +) -> torch.Tensor: + return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k)) + + +top_k_renorm_prob = top_k_renorm_probs + + +def _top_p_renorm_probs_internal( + probs: torch.Tensor, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, +) -> 
torch.Tensor: + with probs.device as device: + probs = probs.float() + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + renorm_probs = torch.empty_like(probs) + torch.ops.sgl_kernels.top_p_renorm_probs( + probs, + renorm_probs, + maybe_top_p_arr, + top_p_val, + _get_cuda_stream(device), + ) + return renorm_probs + + +def top_p_renorm_probs( + probs: torch.Tensor, + top_p: Union[torch.Tensor, float], +) -> torch.Tensor: + return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p)) + + +top_p_renorm_prob = top_p_renorm_probs + + +def _top_p_sampling_from_probs_internal( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=device) + torch.ops.sgl_kernels.top_p_sampling_from_probs( + probs, + uniform_samples, + samples, + success, + maybe_top_p_arr, + top_p_val, + deterministic, + _get_cuda_stream(device), + ) + return samples, success + + +def top_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + top_p: Union[torch.Tensor, float], + deterministic: bool = True, + check_nan: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") + return _top_p_sampling_from_probs_internal( + probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic + ) + + +def _top_k_top_p_sampling_from_probs_internal( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=device) + torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs( + probs, + uniform_samples, + samples, + success, + maybe_top_k_arr, + top_k_val, + maybe_top_p_arr, + top_p_val, + deterministic, + _get_cuda_stream(device), + ) + return samples, success + + +def top_k_top_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + top_k: Union[torch.Tensor, int], + top_p: Union[torch.Tensor, float], + filter_apply_order: str = "top_k_first", + deterministic: bool = True, + check_nan: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if filter_apply_order == "top_k_first": + renorm_probs = top_k_renorm_probs(probs, top_k) + return top_p_sampling_from_probs( + renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan + ) + elif filter_apply_order == "joint": + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") + return _top_k_top_p_sampling_from_probs_internal( + probs, + uniform_samples, + *_to_tensor_scalar_tuple(top_k), + 
*_to_tensor_scalar_tuple(top_p), + deterministic, + ) + else: + raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") + + +def _min_p_sampling_from_probs_internal( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_min_p_arr: Optional[torch.Tensor], + min_p_val: float, + deterministic: bool, +) -> torch.Tensor: + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_min_p_arr = ( + maybe_min_p_arr.float() if maybe_min_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + torch.ops.sgl_kernels.min_p_sampling_from_probs( + probs, + uniform_samples, + samples, + maybe_min_p_arr, + min_p_val, + deterministic, + _get_cuda_stream(device), + ) + return samples + + +def min_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + min_p: Union[torch.Tensor, float], + deterministic: bool = True, + check_nan: bool = False, +) -> torch.Tensor: + if uniform_samples.dim() == 2: + # Take the first row (round) of uniform_samples + uniform_samples = uniform_samples[0] + + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") + return _min_p_sampling_from_probs_internal( + probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic + ) diff --git a/sgl-kernel/src/sgl-kernel/ops/utils.py b/sgl-kernel/src/sgl-kernel/ops/utils.py new file mode 100644 index 00000000000..31a6bbf9919 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/ops/utils.py @@ -0,0 +1,26 @@ +from typing import Dict, Tuple + +import torch + + +def _get_cuda_stream(device: torch.device) -> int: + return torch.cuda.current_stream(device).cuda_stream + + +_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {} + + +def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor: + key = (name, device) + buf = _cache_buf.get(key) + if buf is None: + buf = torch.empty(bytes, dtype=torch.uint8, device=device) + _cache_buf[key] = buf + return buf + + +def _to_tensor_scalar_tuple(x): + if isinstance(x, torch.Tensor): + return (x, 0) + else: + return (None, x) diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc new file mode 100644 index 00000000000..01f93199ccb --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc @@ -0,0 +1,120 @@ +#include +#include + +#include "sgl_kernels_ops.h" + +TORCH_LIBRARY_EXPAND(sgl_kernels, m) { + // trt_reduce + m.def( + "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] " + "barrier_in, int[] barrier_out) -> int"); + m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); + + m.def("dispose", &dispose); + + m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()"); + m.impl("all_reduce", torch::kCUDA, &all_reduce); + + m.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])"); + m.impl("get_graph_buffer_ipc_meta", torch::kCUDA, &get_graph_buffer_ipc_meta); + + m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()"); + m.impl("register_graph_buffers", torch::kCUDA, ®ister_graph_buffers); + + // moe_align_block_size + m.def( + "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " + "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! 
cumsum_buffer) -> ()"); + m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + + // int8_scaled_mm + m.def( + "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? " + "bias) -> Tensor"); + m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm); + + // fp8_scaled_mm + m.def( + "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? " + "bias) -> Tensor"); + m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm); + + // lightning_attention_decode + m.def( + "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! " + "new_kv) -> ()"); + m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode); + + // rms norm + m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("rmsnorm", torch::kCUDA, &rmsnorm); + + // fused rms norm + m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()"); + m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm); + + // gemma rms norm + m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm); + + // fused gemma rms norm + m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm); + + // silu and mul + m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + + // gelu tanh and mul + m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); + + // gelu and mul + m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); + + // bmm fp8 + m.def( + "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int " + "cublas_handle, int cuda_stream) -> ()"); + m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8); + + // min p sampling from probs + m.def( + "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float " + "min_p_val, bool deterministic, int cuda_stream) -> ()"); + m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs); + + // top k renorm probs + m.def( + "top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int " + "cuda_stream) -> ()"); + m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper); + + // top p renorm probs + m.def( + "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int " + "cuda_stream) -> ()"); + m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs); + + // top k top p sampling from probs + m.def( + "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " + "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int " + "cuda_stream) -> ()"); + m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs); + + // top p sampling from probs + m.def( + "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! 
success, Tensor? " + "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()"); + m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs); + + // apply rope with cos sin cache + m.def( + "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, " + "Tensor pos_ids, bool interleave, int cuda_stream) -> ()"); + m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache); +} + +REGISTER_EXTENSION(_kernels) diff --git a/sgl-kernel/tests/.gitkeep b/sgl-kernel/tests/.gitkeep deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/sgl-kernel/tests/test_activation.py b/sgl-kernel/tests/test_activation.py new file mode 100644 index 00000000000..43593441e3b --- /dev/null +++ b/sgl-kernel/tests/test_activation.py @@ -0,0 +1,39 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_activation.py + +import pytest +import sgl_kernel +import torch + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_silu_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim]) + y = sgl_kernel.silu_and_mul(x) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_gelu_tanh_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh") + y = sgl_kernel.gelu_tanh_and_mul(x) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_gelu_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none") + y = sgl_kernel.gelu_and_mul(x) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_bmm_fp8.py b/sgl-kernel/tests/test_bmm_fp8.py new file mode 100644 index 00000000000..e0be92896f6 --- /dev/null +++ b/sgl-kernel/tests/test_bmm_fp8.py @@ -0,0 +1,43 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_bmm_fp8.py + +import pytest +import torch +import torch.nn.functional as F +from sgl_kernel import bmm_fp8 + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, 
torch.float8_e5m2]) +@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) +def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype): + if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2: + pytest.skip("Invalid combination: both input and mat2 are e5m2") + + input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16) + input_fp8, input_inv_s = to_float8(input, dtype=input_dtype) + + # mat2 row major -> column major + mat2 = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose( + -2, -1 + ) + mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype) + + res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype) + bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res) + + reference = torch.bmm(input, mat2) + cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0) + assert cos_sim > 0.99 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_fp8_gemm.py b/sgl-kernel/tests/test_fp8_gemm.py new file mode 100644 index 00000000000..1a731865944 --- /dev/null +++ b/sgl-kernel/tests/test_fp8_gemm.py @@ -0,0 +1,67 @@ +import unittest + +import torch +from sgl_kernel import fp8_scaled_mm + + +def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): + o = torch.matmul(a.to(torch.float32), b.to(torch.float32)) + + o = o.to(torch.float32) + temp1 = o * scale_a.view(-1, 1) + temp2 = temp1 * scale_b.view(1, -1) + final = temp2.to(out_dtype) + if bias is not None: + final = final + bias.view(1, -1) + + return final + + +class TestFp8Gemm(unittest.TestCase): + def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a_fp32 = ( + (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + ) + a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + b_fp32 = ( + (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + ) + b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001 + scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001 + if with_bias: + bias = torch.randn((N,), device=device, dtype=out_dtype) + else: + bias = None + o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16) + b_fp8 = b_fp8.t() + o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) + o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) + rtol = 0.02 + atol = 1 + torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) + print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") + + def test_accuracy(self): + Ms = [1, 128, 512, 1024, 4096] + Ns = [16, 128, 512, 1024, 4096] + Ks = [512, 1024, 4096, 8192, 16384] + bias_opts = [True, False] + out_dtypes = [torch.bfloat16, torch.float16] + for M in Ms: + for N in Ns: + for K in Ks: + for with_bias in bias_opts: + for out_dtype in out_dtypes: + self._test_accuracy_once( + M, N, K, with_bias, out_dtype, "cuda" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sgl-kernel/tests/test_int8_gemm.py b/sgl-kernel/tests/test_int8_gemm.py new file mode 100644 index 00000000000..c33a3effcaf --- /dev/null +++ b/sgl-kernel/tests/test_int8_gemm.py @@ -0,0 +1,56 @@ +import unittest + +import torch +from sgl_kernel import int8_scaled_mm +from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm + + +def 
to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): + o = torch.matmul(a.to(torch.float32), b.to(torch.float32)) + if bias is not None: + o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1) + bias + else: + o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1) + return o.to(out_dtype) + + +class TestInt8Gemm(unittest.TestCase): + def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): + a = to_int8(torch.randn((M, K), device=device) * 5) + b = to_int8(torch.randn((N, K), device=device).t() * 5) + scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) + scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) + if with_bias: + bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10 + else: + bias = None + + o = int8_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + o1 = torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + o2 = vllm_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) + torch.testing.assert_close(o, o1) + torch.testing.assert_close(o, o2) + print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") + + def test_accuracy(self): + Ms = [1, 128, 512, 1024, 4096, 8192] + Ns = [16, 128, 512, 1024, 4096, 8192, 16384] + Ks = [512, 1024, 4096, 8192, 16384] + bias_opts = [True, False] + out_dtypes = [torch.float16, torch.bfloat16] + for M in Ms: + for N in Ns: + for K in Ks: + for with_bias in bias_opts: + for out_dtype in out_dtypes: + self._test_accuracy_once( + M, N, K, with_bias, out_dtype, "cuda" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sgl-kernel/tests/test_lightning_attention_decode.py b/sgl-kernel/tests/test_lightning_attention_decode.py new file mode 100644 index 00000000000..f2cace00157 --- /dev/null +++ b/sgl-kernel/tests/test_lightning_attention_decode.py @@ -0,0 +1,88 @@ +import pytest +import torch +from sgl_kernel import lightning_attention_decode + + +def naive_lightning_attention_decode(q, k, v, past_kv, slope): + """Naive implementation of lightning attention decode""" + original_dtype = q.dtype + ratio = torch.exp(-slope) # [h, 1, 1] + + kv = past_kv + b, h, n, d = q.shape + + output = [] + for i in range(n): + kv = ratio * kv.to(torch.float32) + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + qkv = torch.einsum( + "... n e, ... e d -> ... 
n d", + q[:, :, i : i + 1].to(torch.float32), + kv.to(torch.float32), + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + + return output.to(original_dtype), kv + + +configs = [ + # (batch_size, num_heads, dim, embed_dim) + (1, 8, 64, 64), + (2, 8, 64, 64), + (1, 32, 32, 64), + (2, 32, 32, 64), + (4, 32, 64, 64), + (4, 32, 64, 64), + (16, 64, 96, 96), + (64, 64, 96, 96), +] + +dtypes = [torch.float32, torch.float16, torch.bfloat16] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("batch_size,num_heads,dim,embed_dim", configs) +def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim): + device = torch.device("cuda") + + q = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype) + k = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype) + v = torch.randn(batch_size, num_heads, 1, embed_dim, device=device, dtype=dtype) + past_kv = torch.randn(batch_size, num_heads, dim, embed_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + ref_output, ref_new_kv = naive_lightning_attention_decode(q, k, v, past_kv, slope) + + output = torch.empty_like(ref_output) + new_kv = torch.empty_like(ref_new_kv) + lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) + + rtol = 1e-2 + atol = 1e-2 + + torch.testing.assert_close( + output, + ref_output, + rtol=rtol, + atol=atol, + msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, " + f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}", + ) + + torch.testing.assert_close( + new_kv, + ref_new_kv, + rtol=rtol, + atol=atol, + msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, " + f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}", + ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py new file mode 100644 index 00000000000..2fca90b2f56 --- /dev/null +++ b/sgl-kernel/tests/test_moe_align.py @@ -0,0 +1,67 @@ +import torch +from sgl_kernel import moe_align_block_size + + +def test_moe_align_block_size(): + # For DeepSeek V3, we have 256 experts + num_experts = 256 + + # Test different combinations of block_size, num_tokens and topk + for block_size in [32, 64, 128, 256]: + print(f"\nTesting block_size={block_size}") + for num_tokens in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]: + for topk in [1, 2, 4, 8, 16, 32, 64]: + print( + f"Testing block_size={block_size}, num_tokens={num_tokens}, topk={topk}" + ) + + # Create random topk_ids with shape [num_tokens, topk] + topk_ids = torch.randint( + 0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" + ) + + max_num_tokens_padded = topk_ids.numel() + num_experts * ( + block_size - 1 + ) + sorted_ids = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad = torch.empty( + (1), dtype=torch.int32, device=topk_ids.device + ) + + token_cnts_buffer = torch.empty( + (num_experts + 1) * num_experts, + dtype=torch.int32, + device=topk_ids.device, + ) + cumsum_buffer = torch.empty( + num_experts + 1, dtype=torch.int32, device=topk_ids.device + ) + + try: + moe_align_block_size( + topk_ids, + num_experts, + block_size, + 
sorted_ids, + expert_ids, + num_tokens_post_pad, + token_cnts_buffer, + cumsum_buffer, + ) + except Exception as e: + print( + f"Error occurred with block_size={block_size}, num_tokens={num_tokens}, topk={topk}" + ) + print(f"Error message: {str(e)}") + raise e + + +if __name__ == "__main__": + test_moe_align_block_size() diff --git a/sgl-kernel/tests/test_norm.py b/sgl-kernel/tests/test_norm.py new file mode 100644 index 00000000000..d22da931f57 --- /dev/null +++ b/sgl-kernel/tests/test_norm.py @@ -0,0 +1,133 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_norm.py + +import pytest +import sgl_kernel +import torch + + +def llama_rms_norm(x, w, eps=1e-6): + orig_dtype = x.dtype + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = x * w.float() + x = x.to(orig_dtype) + return x + + +def gemma_rms_norm(x, w, eps=1e-6): + orig_dtype = x.dtype + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = x * (1.0 + w.float()) + x = x.to(orig_dtype) + return x + + +def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6): + orig_dtype = x.dtype + x = x + residual + residual = x + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = x * (1.0 + w.float()) + x = x.to(orig_dtype) + return x, residual + + +def fused_add_rms_norm(x, residual, weight, eps): + orig_dtype = x.dtype + x = x.to(torch.float32) + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = (x * weight.float()).to(orig_dtype) + return x, residual + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("specify_out", [True, False]) +def test_norm(batch_size, hidden_size, dtype, specify_out): + x = torch.randn(batch_size, hidden_size).to(0).to(dtype) + w = torch.randn(hidden_size).to(0).to(dtype) + + y_ref = llama_rms_norm(x, w) + if specify_out: + y = torch.empty_like(x) + sgl_kernel.rmsnorm(x, w, out=y) + else: + y = sgl_kernel.rmsnorm(x, w) + + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +def test_fused_add_rmsnorm(batch_size, hidden_size, dtype): + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) + weight = torch.randn(hidden_size, dtype=dtype, device="cuda") + + x_native, residual_native = fused_add_rms_norm( + x.clone(), residual.clone(), weight, eps + ) + + x_fused = x.clone() + residual_fused = residual.clone() + sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps) + + torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("specify_out", [True, False]) +def test_gemma_norm(batch_size, hidden_size, 
dtype, specify_out): + x = torch.randn(batch_size, hidden_size).to(0).to(dtype) + w = torch.randn(hidden_size).to(0).to(dtype) + + y_ref = gemma_rms_norm(x, w) + if specify_out: + y = torch.empty_like(x) + sgl_kernel.gemma_rmsnorm(x, w, out=y) + else: + y = sgl_kernel.gemma_rmsnorm(x, w) + + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype): + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) + weight = torch.randn(hidden_size, dtype=dtype, device="cuda") + + x_native, residual_native = gemma_fused_add_rms_norm( + x.clone(), residual.clone(), weight, eps + ) + + x_fused = x.clone() + residual_fused = residual.clone() + sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps) + + torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_rotary_embedding.py b/sgl-kernel/tests/test_rotary_embedding.py new file mode 100644 index 00000000000..b7a141404e6 --- /dev/null +++ b/sgl-kernel/tests/test_rotary_embedding.py @@ -0,0 +1,202 @@ +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import pytest +import torch +import torch.nn as nn +from sgl_kernel import apply_rope_with_cos_sin_cache_inplace + + +# vLLM torch native +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. 
+ """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(torch.nn.Module): + # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + + # Modification: float32 is required for the rotary embedding to work correctly + query = query.to(torch.float32) + key = key.to(torch.float32) + + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + + # Modification: convert to the correct dtype + query = query.to(self.dtype) + key = key.to(self.dtype) + return query, key + + +class FlashInferRotaryEmbedding(RotaryEmbedding): + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=query, + key=key, + head_size=self.head_size, + cos_sin_cache=self.cos_sin_cache, + is_neox=self.is_neox_style, + ) + + return query, key + + +@pytest.mark.parametrize( + "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, 
num_kv_heads", + [ + (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1), + (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2), + (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2), + (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8), + (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4), + (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2), + ], +) +def test_correctness( + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + device: str, + batch_size: int, + seq_len: int, + num_q_heads: int, + num_kv_heads: int, +): + rope_ref = RotaryEmbedding( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ).to(device) + rope_flashinfer = FlashInferRotaryEmbedding( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ).to(device) + + pos_ids = torch.arange(seq_len, device=device).repeat(batch_size) + query = torch.randn( + batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device + ) + key = torch.randn( + batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device + ) + + query_ref, key_ref = query.clone(), key.clone() + query_flashinfer, key_flashinfer = query.clone(), key.clone() + + query_ref_out, key_ref_out = rope_ref.forward_native(pos_ids, query_ref, key_ref) + query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda( + pos_ids, query_flashinfer, key_flashinfer + ) + + print(query_ref_out) + print(query_flashinfer_out) + + torch.testing.assert_close( + query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2 + ) + torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_sampling.py b/sgl-kernel/tests/test_sampling.py new file mode 100644 index 00000000000..7d3bc5059ee --- /dev/null +++ b/sgl-kernel/tests/test_sampling.py @@ -0,0 +1,141 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/93e1a2634e22355b0856246b032b285ad1d1da6b/tests/test_sampling.py + +import pytest +import sgl_kernel +import torch + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.1, 0.5]) +def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): + torch.manual_seed(42) + if p == 0.1: + k = int(vocab_size * 0.5) + elif p == 0.5: + k = int(vocab_size * 0.1) + else: + raise ValueError("p not recognized") + max_top_k_trails = 32 + eps = 1e-4 + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + # top-p mask + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int()) + # top-k mask + sorted_prob, _ = torch.sort(normalized_prob, descending=True) + pivot = sorted_prob[:, k - 1] + mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int() + # overall mask + mask = torch.minimum(mask_top_p, mask_top_k) + uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to( + 0 + ) + top_p_tensor = torch.full((batch_size,), p).to(0) + top_k_tensor = torch.full((batch_size,), k).to(0) + + num_trails = 1000 + for _ in 
range(num_trails): + uniform_samples.uniform_() + samples, success = sgl_kernel.top_k_top_p_sampling_from_probs( + normalized_prob, + uniform_samples, + top_k_tensor, + top_p_tensor, + filter_apply_order="joint", + ) + assert torch.all(success) + assert torch.all(samples < vocab_size) and torch.all(samples >= 0) + assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[ + torch.arange(batch_size), samples + ] + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) +def test_top_p_renorm_probs(batch_size, vocab_size, p): + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask.scatter_add_(1, indices, (cdf >= (1 - p)).int()) + renorm_prob_ground_truth = normalized_prob + renorm_prob_ground_truth[mask == 0] = 0 + renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( + dim=-1, keepdim=True + ) + + renorm_prob = sgl_kernel.top_p_renorm_prob(normalized_prob, p) + torch.testing.assert_close( + renorm_prob_ground_truth, + renorm_prob, + rtol=1e-3, + atol=1e-3, + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("k", [10, 100, 500]) +def test_top_k_renorm_probs(batch_size, vocab_size, k): + if k > vocab_size: + pytest.skip("k should be less than vocab_size") + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, _ = torch.sort(normalized_prob, descending=True) + pivot = sorted_prob[:, k - 1] + mask = (normalized_prob >= pivot.unsqueeze(-1)).int() + renorm_prob_ground_truth = normalized_prob + renorm_prob_ground_truth[mask == 0] = 0 + renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( + dim=-1, keepdim=True + ) + + renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k) + torch.testing.assert_close( + renorm_prob_ground_truth, + renorm_prob, + rtol=1e-3, + atol=1e-3, + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1]) +def test_min_p_sampling(batch_size, vocab_size, p): + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + # scale min-p + top_probs = sorted_prob[:, -1].unsqueeze(-1) + scaled_p = p * top_probs + # min-p mask + mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int()) + uniform_samples = torch.empty(batch_size, dtype=torch.float32).to(0) + min_p_tensor = torch.full((batch_size,), p).to(0) + + num_trails = 1000 + for _ in range(num_trails): + uniform_samples.uniform_() + samples = sgl_kernel.min_p_sampling_from_probs( + normalized_prob, + uniform_samples, + min_p_tensor, + ) + + assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[ + torch.nonzero(mask[torch.arange(batch_size), samples] 
== 0) + ] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_trt_reduce.py b/sgl-kernel/tests/test_trt_reduce.py new file mode 100644 index 00000000000..b79580070c0 --- /dev/null +++ b/sgl-kernel/tests/test_trt_reduce.py @@ -0,0 +1,246 @@ +import ctypes +import logging +import os +import random +import socket +import time +import unittest +from typing import Any, List, Optional, Union + +import ray +import torch +import torch.distributed as dist +from sgl_kernel import ops as custom_ops +from torch.distributed import ProcessGroup +from vllm import _custom_ops as vllm_ops + +from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary + +logger = logging.getLogger(__name__) + + +def get_open_port() -> int: + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("::1", 0)) + return s.getsockname()[1] + + +def multi_process_parallel( + world_size: int, + cls: Any, + test_target: Any, +) -> None: + # Using ray helps with debugging errors when a test fails, + # compared to multiprocessing. + # NOTE: We need to set working_dir for distributed tests, + # otherwise we may get import errors on ray workers + ray.init(log_to_driver=True) + + distributed_init_port = get_open_port() + refs = [] + for rank in range(world_size): + refs.append(test_target.remote(cls, world_size, rank, distributed_init_port)) + ray.get(refs) + + ray.shutdown() + + +class TestCustomAllReduce(unittest.TestCase): + @classmethod + def setUpClass(cls): + random.seed(42) + cls.test_sizes = [512, 4096, 32768, 262144, 524288, 1048576, 2097152] + cls.world_sizes = [2, 4, 8] + + @staticmethod + def create_shared_buffer( + size_in_bytes: int, group: Optional[ProcessGroup] = None + ) -> List[int]: + """ + Creates a shared buffer and returns a list of pointers + representing the buffer on all processes in the group. 
+ """ + lib = CudaRTLibrary() + pointer = lib.cudaMalloc(size_in_bytes) + handle = lib.cudaIpcGetMemHandle(pointer) + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=group) + + pointers: List[int] = [] + for i, h in enumerate(handles): + if i == rank: + pointers.append(pointer.value) # type: ignore + else: + pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore + + return pointers + + @staticmethod + def free_shared_buffer( + pointers: List[int], group: Optional[ProcessGroup] = None + ) -> None: + rank = dist.get_rank(group=group) + lib = CudaRTLibrary() + lib.cudaFree(ctypes.c_void_p(pointers[rank])) + + def test_correctness(self): + for world_size in self.world_sizes: + if world_size > torch.cuda.device_count(): + continue + multi_process_parallel(world_size, self, self.correctness) + + def test_performance(self): + for world_size in self.world_sizes: + if world_size > torch.cuda.device_count(): + continue + multi_process_parallel(world_size, self, self.performance) + + def init_custom_allreduce(self, rank, world_size, group): + buffer_max_size = 8 * 1024 * 1024 + barrier_max_size = 8 * (24 + 2) * 8 + + self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group) + self.tmp_result_buffer_ptrs = self.create_shared_buffer( + buffer_max_size, group=group + ) + self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group) + self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group) + self.rank_data = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}") + ) + + self.custom_ptr = custom_ops.init_custom_reduce( + rank, + world_size, + self.rank_data, + self.buffer_ptrs, + self.tmp_result_buffer_ptrs, + self.barrier_in_ptrs, + self.barrier_out_ptrs, + ) + + def custom_allreduce(self, inp, out): + custom_ops.custom_reduce(self.custom_ptr, inp, out) + + def free_custom_allreduce(self, group): + self.free_shared_buffer(self.buffer_ptrs, group) + self.free_shared_buffer(self.tmp_result_buffer_ptrs, group) + self.free_shared_buffer(self.barrier_in_ptrs, group) + self.free_shared_buffer(self.barrier_out_ptrs, group) + custom_ops.custom_dispose(self.custom_ptr) + + def init_vllm_allreduce(self, rank, group): + self.vllm_rank = rank + self.vllm_max_size = 8 * 1024 * 1024 + self.vllm_meta_ptrs = self.create_shared_buffer( + vllm_ops.meta_size() + self.vllm_max_size, group=group + ) + self.vllm_buffer_ptrs = self.create_shared_buffer( + self.vllm_max_size, group=group + ) + self.vllm_rank_data = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}") + ) + self.vllm_ptr = vllm_ops.init_custom_ar( + self.vllm_meta_ptrs, self.vllm_rank_data, rank, True + ) + vllm_ops.register_buffer(self.vllm_ptr, self.vllm_buffer_ptrs) + + def vllm_allreduce(self, inp, out): + vllm_ops.all_reduce( + self.vllm_ptr, + inp, + out, + self.vllm_buffer_ptrs[self.vllm_rank], + self.vllm_max_size, + ) + + def free_vllm_allreduce(self, group): + vllm_ops.dispose(self.vllm_ptr) + self.free_shared_buffer(self.vllm_meta_ptrs, group) + self.free_shared_buffer(self.vllm_buffer_ptrs, group) + + @staticmethod + def init_distributed_env(world_size, rank, distributed_init_port): + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + ranks = [i for i in range(world_size)] + distributed_init_method = f"tcp://localhost:{distributed_init_port}" 
+ dist.init_process_group( + backend="nccl", + init_method=distributed_init_method, + rank=rank, + world_size=world_size, + ) + group = torch.distributed.new_group(ranks, backend="gloo") + return group + + # compare result with torch.distributed + @ray.remote(num_gpus=1, max_calls=1) + def correctness(self, world_size, rank, distributed_init_port): + group = self.init_distributed_env(world_size, rank, distributed_init_port) + + self.init_custom_allreduce(rank=rank, world_size=world_size, group=group) + + test_loop = 10 + for sz in self.test_sizes: + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + for _ in range(test_loop): + inp1 = torch.randint( + 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) + out1 = torch.empty_like(inp1) + self.custom_allreduce(inp1, out1) + + dist.all_reduce(inp1, group=group) + torch.testing.assert_close(out1, inp1) + + self.free_custom_allreduce(group) + + # compare performance with vllm + @ray.remote(num_gpus=1, max_calls=1) + def performance(self, world_size, rank, distributed_init_port): + group = self.init_distributed_env(world_size, rank, distributed_init_port) + + self.init_vllm_allreduce(rank, group) + self.init_custom_allreduce(rank=rank, world_size=world_size, group=group) + + for sz in self.test_sizes: + inp1 = torch.randint( + 1, 16, (sz,), dtype=torch.float32, device=torch.cuda.current_device() + ) + out1 = torch.empty_like(inp1) + test_loop = 5000 + start = time.time() + for _ in range(test_loop): + self.custom_allreduce(inp1, out1) + elapse_custom = time.time() - start + + start = time.time() + for _ in range(test_loop): + self.vllm_allreduce(inp1, out1) + elapse_vllm = time.time() - start + + if rank == 0: + logger.warning( + f"test_size = {sz}, world_size = {world_size}, " + f"vllm time = {elapse_vllm * 1000 / test_loop:.4f}ms," + f"custom time = {elapse_custom * 1000 / test_loop:.4f}ms" + ) + + self.free_custom_allreduce(group) + self.free_vllm_allreduce(group) + + +if __name__ == "__main__": + unittest.main() diff --git a/sgl-kernel/version.py b/sgl-kernel/version.py new file mode 100644 index 00000000000..647733203b6 --- /dev/null +++ b/sgl-kernel/version.py @@ -0,0 +1 @@ +__version__ = "0.0.3.post1" diff --git a/rust/Cargo.lock b/sgl-router/Cargo.lock similarity index 99% rename from rust/Cargo.lock rename to sgl-router/Cargo.lock index 37c2733fdc0..dc9c46a7146 100644 --- a/rust/Cargo.lock +++ b/sgl-router/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. 
-version = 3 +version = 4 [[package]] name = "actix-codec" @@ -851,6 +851,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -1986,6 +1987,7 @@ dependencies = [ "base64 0.22.1", "bytes", "encoding_rs", + "futures-channel", "futures-core", "futures-util", "h2 0.4.6", @@ -2219,6 +2221,7 @@ dependencies = [ "serde", "serde_json", "tokenizers", + "tokio", ] [[package]] @@ -2475,9 +2478,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.41.0" +version = "1.42.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "145f3413504347a2be84393cc8a7d2fb4d863b375909ea59f2158261aa258bbb" +checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551" dependencies = [ "backtrace", "bytes", diff --git a/rust/Cargo.toml b/sgl-router/Cargo.toml similarity index 86% rename from rust/Cargo.toml rename to sgl-router/Cargo.toml index 5ac77665bcc..2173dba086c 100644 --- a/rust/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -3,10 +3,6 @@ name = "sglang_router_rs" version = "0.0.0" edition = "2021" -[[bin]] -name = "sglang_router_rs" -path = "src/main.rs" - [lib] name = "sglang_router_rs" # Pure Rust library: Just omit crate-type (defaults to rlib) @@ -19,7 +15,7 @@ serde = { version = "1.0", features = ["derive"] } clap = { version = "4.4", features = ["derive"] } bytes = "1.8.0" rand = "0.8.5" -reqwest = { version = "0.12.8", features = ["stream"] } +reqwest = { version = "0.12.8", features = ["stream", "blocking"] } futures-util = "0.3" serde_json = "1.0" pyo3 = { version = "0.22.5", features = ["extension-module"] } @@ -29,6 +25,7 @@ http = "1.1.0" env_logger = "0.11.5" log = "0.4.22" chrono = "0.4.38" +tokio = "1.42.0" [profile.release] lto = "thin" diff --git a/rust/MANIFEST.in b/sgl-router/MANIFEST.in similarity index 100% rename from rust/MANIFEST.in rename to sgl-router/MANIFEST.in diff --git a/sgl-router/README.md b/sgl-router/README.md new file mode 100644 index 00000000000..61c9e692c92 --- /dev/null +++ b/sgl-router/README.md @@ -0,0 +1,97 @@ +# SGLang Router + +SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances. + +## User docs + +Please check https://docs.sglang.ai/router/router.html + +## Developer docs + +### Prerequisites + +- Rust and Cargo installed + +```bash +# Install rustup (Rust installer and version manager) +curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh + +# Follow the installation prompts, then reload your shell +source $HOME/.cargo/env + +# Verify installation +rustc --version +cargo --version +``` + +- Python with pip installed + + +### Build Process + +#### 1. Build Rust Project + +```bash +$ cargo build +``` + +#### 2. Build Python Binding + +##### Option A: Build and Install Wheel +1. Build the wheel package: +```bash +$ pip install setuptools-rust wheel build +$ python -m build +``` + +2. Install the generated wheel: +```bash +$ pip install +``` + +If you want one handy command to do build + install for every change you make: + +```bash +$ python -m build && pip install --force-reinstall dist/*.whl +``` + +##### Option B: Development Mode + +For development purposes, you can install the package in editable mode: + +Warning: Using editable python binding can suffer from performance degradation!! Please build a fresh wheel for every update if you want to test performance. + +```bash +$ pip install -e . 
+``` + +**Note:** When modifying Rust code, you must rebuild the wheel for changes to take effect. + +### Troubleshooting + +1. If rust analyzer is not working in VSCode, set `rust-analyzer.linkedProjects` to the absolute path of `Cargo.toml` in your repo. For example: + +```json +{ + "rust-analyzer.linkedProjects": ["/workspaces/sglang/sgl-router/Cargo.toml"] +} +``` + +### CI/CD Setup + +The continuous integration pipeline consists of three main steps: + +#### 1. Build Wheels +- Uses `cibuildwheel` to create manylinux x86_64 packages +- Compatible with major Linux distributions (Ubuntu, CentOS, etc.) +- Additional configurations can be added to support other OS/architectures +- Reference: [cibuildwheel documentation](https://cibuildwheel.pypa.io/en/stable/) + +#### 2. Build Source Distribution +- Creates a source distribution containing the raw, unbuilt code +- Enables `pip` to build the package from source when prebuilt wheels are unavailable + +#### 3. Publish to PyPI +- Uploads both wheels and source distribution to PyPI + +The CI configuration is based on the [tiktoken workflow](https://github.com/openai/tiktoken/blob/63527649963def8c759b0f91f2eb69a40934e468/.github/workflows/build_wheels.yml#L1). diff --git a/rust/py_src/sglang_router/__init__.py b/sgl-router/py_src/sglang_router/__init__.py similarity index 52% rename from rust/py_src/sglang_router/__init__.py rename to sgl-router/py_src/sglang_router/__init__.py index ec41b5d0c74..081740479ca 100644 --- a/rust/py_src/sglang_router/__init__.py +++ b/sgl-router/py_src/sglang_router/__init__.py @@ -1,5 +1,7 @@ # a lightweihgt wrapper on router with argument type and comments +# no wrapper on policy type => direct export +from sglang_router.router import Router +from sglang_router.version import __version__ from sglang_router_rs import PolicyType -# no wrapper on policy type => direct export -from .router import Router +__all__ = ["Router", "PolicyType", "__version__"] diff --git a/rust/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py similarity index 82% rename from rust/py_src/sglang_router/launch_router.py rename to sgl-router/py_src/sglang_router/launch_router.py index a8b6adf0388..38f1fbba2dc 100644 --- a/rust/py_src/sglang_router/launch_router.py +++ b/sgl-router/py_src/sglang_router/launch_router.py @@ -27,17 +27,20 @@ def setup_logger(): @dataclasses.dataclass class RouterArgs: # Worker configuration - worker_urls: List[str] + worker_urls: List[str] = dataclasses.field(default_factory=list) host: str = "127.0.0.1" port: int = 30000 # Routing policy policy: str = "cache_aware" + worker_startup_timeout_secs: int = 300 + worker_startup_check_interval: int = 10 cache_threshold: float = 0.5 balance_abs_threshold: int = 32 balance_rel_threshold: float = 1.0001 eviction_interval: int = 60 max_tree_size: int = 2**24 + max_payload_size: int = 4 * 1024 * 1024 # 4MB verbose: bool = False @staticmethod @@ -86,6 +89,18 @@ def add_cli_args( choices=["random", "round_robin", "cache_aware"], help="Load balancing policy to use", ) + parser.add_argument( + f"--{prefix}worker-startup-timeout-secs", + type=int, + default=RouterArgs.worker_startup_timeout_secs, + help="Timeout in seconds for worker startup", + ) + parser.add_argument( + f"--{prefix}worker-startup-check-interval", + type=int, + default=RouterArgs.worker_startup_check_interval, + help="Interval in seconds between checks for worker startup", + ) parser.add_argument( f"--{prefix}cache-threshold", type=float, @@ -116,6 +131,12 @@ def 
add_cli_args( default=RouterArgs.max_tree_size, help="Maximum size of the approximation tree for cache-aware routing", ) + parser.add_argument( + f"--{prefix}max-payload-size", + type=int, + default=RouterArgs.max_payload_size, + help="Maximum payload size in bytes", + ) parser.add_argument( f"--{prefix}verbose", action="store_true", @@ -134,16 +155,24 @@ def from_cli_args( use_router_prefix: If True, look for arguments with 'router-' prefix """ prefix = "router_" if use_router_prefix else "" + worker_urls = args.worker_urls if args.worker_urls is not None else [] return cls( - worker_urls=args.worker_urls, + worker_urls=worker_urls, host=args.host, port=args.port, policy=getattr(args, f"{prefix}policy"), + worker_startup_timeout_secs=getattr( + args, f"{prefix}worker_startup_timeout_secs" + ), + worker_startup_check_interval=getattr( + args, f"{prefix}worker_startup_check_interval" + ), cache_threshold=getattr(args, f"{prefix}cache_threshold"), balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"), balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"), eviction_interval=getattr(args, f"{prefix}eviction_interval"), max_tree_size=getattr(args, f"{prefix}max_tree_size"), + max_payload_size=getattr(args, f"{prefix}max_payload_size"), verbose=getattr(args, f"{prefix}verbose", False), ) @@ -179,14 +208,17 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: router = Router( worker_urls=router_args.worker_urls, - policy=policy_from_str(router_args.policy), host=router_args.host, port=router_args.port, + policy=policy_from_str(router_args.policy), + worker_startup_timeout_secs=router_args.worker_startup_timeout_secs, + worker_startup_check_interval=router_args.worker_startup_check_interval, cache_threshold=router_args.cache_threshold, balance_abs_threshold=router_args.balance_abs_threshold, balance_rel_threshold=router_args.balance_rel_threshold, eviction_interval_secs=router_args.eviction_interval, max_tree_size=router_args.max_tree_size, + max_payload_size=router_args.max_payload_size, verbose=router_args.verbose, ) @@ -194,8 +226,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: return router except Exception as e: - logger.error(f"Error starting router: {e}", file=sys.stderr) - return None + logger.error(f"Error starting router: {e}") + raise e class CustomHelpFormatter( @@ -228,12 +260,8 @@ def parse_router_args(args: List[str]) -> RouterArgs: def main() -> None: - logger = setup_logger() router_args = parse_router_args(sys.argv[1:]) - router = launch_router(router_args) - - if router is None: - sys.exit(1) + launch_router(router_args) if __name__ == "__main__": diff --git a/rust/py_src/sglang_router/launch_server.py b/sgl-router/py_src/sglang_router/launch_server.py similarity index 56% rename from rust/py_src/sglang_router/launch_server.py rename to sgl-router/py_src/sglang_router/launch_server.py index ec86e8b2adb..74353c21edb 100644 --- a/rust/py_src/sglang_router/launch_server.py +++ b/sgl-router/py_src/sglang_router/launch_server.py @@ -10,12 +10,12 @@ from typing import List import requests +from setproctitle import setproctitle from sglang_router.launch_router import RouterArgs, launch_router -from sglang.srt.server import launch_server +from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import ServerArgs from sglang.srt.utils import is_port_available -from sglang.utils import get_exception_traceback def setup_logger(): @@ -23,7 +23,7 @@ def setup_logger(): 
logger.setLevel(logging.INFO) formatter = logging.Formatter( - "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s", + "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d", datefmt="%Y-%m-%d %H:%M:%S", ) @@ -34,10 +34,41 @@ def setup_logger(): return logger +logger = setup_logger() + + # Create new process group def run_server(server_args, dp_rank): - os.setpgrp() # Create new process group - + """ + Note: + + 1. Without os.setpgrp(), all processes share the same PGID. When you press Ctrl+C, the terminal sends SIGINT to all processes in the group simultaneously. + This can cause leaf processes to terminate first, which messes up the cleaning order and produces orphaned processes. + + Terminal (PGID=100) + └── Main Python Process (PGID=100) + └── Server Process 1 (PGID=100) + └── Scheduler 1 + └── Detokenizer 1 + └── Server Process 2 (PGID=100) + └── Scheduler 2 + └── Detokenizer 2 + + 2. With os.setpgrp(), the main Python process and its children are in a separate group. Now: + + Terminal (PGID=100) + └── Main Python Process (PGID=200) + └── Server Process 1 (PGID=300) + └── Scheduler 1 + └── Detokenizer 1 + └── Server Process 2 (PGID=400) + └── Scheduler 2 + └── Detokenizer 2 + """ + # create new process group + os.setpgrp() + + setproctitle("sglang::server") # Set SGLANG_DP_RANK environment variable os.environ["SGLANG_DP_RANK"] = str(dp_rank) @@ -58,36 +89,6 @@ def launch_server_process( return proc -def cleanup_processes(processes: List[mp.Process]): - logger = logging.getLogger("router") - logger.info("Cleaning up processes...") - for proc in processes: - if proc.is_alive(): - try: - os.killpg(os.getpgid(proc.pid), signal.SIGTERM) - proc.join(timeout=3) - if proc.is_alive(): - logger.warning( - f"Process {proc.pid} did not terminate gracefully, force killing..." 
- ) - os.killpg(os.getpgid(proc.pid), signal.SIGKILL) - except ProcessLookupError: - pass - - -def setup_signal_handlers(cleanup_func): - """Setup handlers for various termination signals.""" - - def signal_handler(signum, frame): - cleanup_func() - sys.exit(1) - - signal.signal(signal.SIGTERM, signal_handler) - signal.signal(signal.SIGINT, signal_handler) - if hasattr(signal, "SIGQUIT"): - signal.signal(signal.SIGQUIT, signal_handler) - - def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool: """Wait for server to be healthy by checking /health endpoint.""" start_time = time.time() @@ -117,9 +118,31 @@ def find_available_ports(base_port: int, count: int) -> List[int]: return available_ports -def main(): - logger = setup_logger() +def cleanup_processes(processes: List[mp.Process]): + for process in processes: + logger.info(f"Terminating process group {process.pid}") + try: + os.killpg(process.pid, signal.SIGTERM) + except ProcessLookupError: + # Process group may already be terminated + pass + # Wait for processes to terminate + for process in processes: + process.join(timeout=5) + if process.is_alive(): + logger.warning( + f"Process {process.pid} did not terminate gracefully, forcing kill" + ) + try: + os.killpg(process.pid, signal.SIGKILL) + except ProcessLookupError: + pass + + logger.info("All process groups terminated") + + +def main(): # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes mp.set_start_method("spawn") @@ -148,52 +171,31 @@ def main(): # Start server processes server_processes = [] - try: - for i, worker_port in enumerate(worker_ports): - logger.info(f"Launching DP server process {i} on port {worker_port}") - proc = launch_server_process(server_args, worker_port, i) - server_processes.append(proc) - - # Setup cleanup handler - setup_signal_handlers(lambda: cleanup_processes(server_processes)) + for i, worker_port in enumerate(worker_ports): + logger.info(f"Launching DP server process {i} on port {worker_port}") + proc = launch_server_process(server_args, worker_port, i) + server_processes.append(proc) - # Wait for all servers to be healthy - all_healthy = True - - for port in worker_ports: - if not wait_for_server_health(server_args.host, port): - logger.error(f"Server on port {port} failed to become healthy") - all_healthy = False - break - - if not all_healthy: - logger.error("Not all servers are healthy. Shutting down...") - cleanup_processes(server_processes) - sys.exit(1) - - logger.info("All servers are healthy. Starting router...") - - # Update router args with worker URLs - router_args.worker_urls = [ - f"http://{server_args.host}:{port}" for port in worker_ports - ] - - # Start the router - router = launch_router(router_args) + signal.signal(signal.SIGINT, lambda sig, frame: cleanup_processes(server_processes)) + signal.signal( + signal.SIGTERM, lambda sig, frame: cleanup_processes(server_processes) + ) + signal.signal( + signal.SIGQUIT, lambda sig, frame: cleanup_processes(server_processes) + ) - if router is None: - logger.error("Failed to start router. 
Shutting down...") - cleanup_processes(server_processes) - sys.exit(1) + # Update router args with worker URLs + router_args.worker_urls = [ + f"http://{server_args.host}:{port}" for port in worker_ports + ] - except KeyboardInterrupt: - logger.info("Received shutdown signal...") + # Start the router + try: + launch_router(router_args) except Exception as e: - logger.error(f"Error occurred: {e}") - logger.error(get_exception_traceback()) - finally: - logger.info("Cleaning up processes...") + logger.error(f"Failed to start router: {e}") cleanup_processes(server_processes) + sys.exit(1) if __name__ == "__main__": diff --git a/rust/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py similarity index 82% rename from rust/py_src/sglang_router/router.py rename to sgl-router/py_src/sglang_router/router.py index 91d608b749e..b8757168b24 100644 --- a/rust/py_src/sglang_router/router.py +++ b/sgl-router/py_src/sglang_router/router.py @@ -17,6 +17,8 @@ class Router: - PolicyType.CacheAware: Distribute requests based on cache state and load balance host: Host address to bind the router server. Default: '127.0.0.1' port: Port number to bind the router server. Default: 3001 + worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300 + worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 10 cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. Default: 0.5 @@ -26,6 +28,7 @@ class Router: AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001 eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware routing. Default: 60 + max_payload_size: Maximum payload size in bytes. Default: 4MB max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24 verbose: Enable verbose logging. 
Default: False """ @@ -36,11 +39,14 @@ def __init__( policy: PolicyType = PolicyType.RoundRobin, host: str = "127.0.0.1", port: int = 3001, + worker_startup_timeout_secs: int = 300, + worker_startup_check_interval: int = 10, cache_threshold: float = 0.50, balance_abs_threshold: int = 32, balance_rel_threshold: float = 1.0001, eviction_interval_secs: int = 60, max_tree_size: int = 2**24, + max_payload_size: int = 4 * 1024 * 1024, # 4MB verbose: bool = False, ): self._router = _Router( @@ -48,11 +54,14 @@ def __init__( policy=policy, host=host, port=port, + worker_startup_timeout_secs=worker_startup_timeout_secs, + worker_startup_check_interval=worker_startup_check_interval, cache_threshold=cache_threshold, balance_abs_threshold=balance_abs_threshold, balance_rel_threshold=balance_rel_threshold, eviction_interval_secs=eviction_interval_secs, max_tree_size=max_tree_size, + max_payload_size=max_payload_size, verbose=verbose, ) diff --git a/sgl-router/py_src/sglang_router/version.py b/sgl-router/py_src/sglang_router/version.py new file mode 100644 index 00000000000..bbab0242f6a --- /dev/null +++ b/sgl-router/py_src/sglang_router/version.py @@ -0,0 +1 @@ +__version__ = "0.1.4" diff --git a/rust/py_test/run_suite.py b/sgl-router/py_test/run_suite.py similarity index 100% rename from rust/py_test/run_suite.py rename to sgl-router/py_test/run_suite.py diff --git a/rust/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py similarity index 63% rename from rust/py_test/test_launch_router.py rename to sgl-router/py_test/test_launch_router.py index 787c091bf9c..27ed64d6e66 100644 --- a/rust/py_test/test_launch_router.py +++ b/sgl-router/py_test/test_launch_router.py @@ -22,22 +22,32 @@ def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> class TestLaunchRouter(unittest.TestCase): - def test_launch_router_no_exception(self): - - # Create SimpleNamespace with default arguments - args = SimpleNamespace( - worker_urls=["http://localhost:8000"], + def setUp(self): + """Set up default arguments for router tests.""" + self.default_args = SimpleNamespace( host="127.0.0.1", port=30000, policy="cache_aware", + worker_startup_timeout_secs=600, + worker_startup_check_interval=10, cache_threshold=0.5, balance_abs_threshold=32, balance_rel_threshold=1.0001, eviction_interval=60, max_tree_size=2**24, + max_payload_size=4 * 1024 * 1024, # 4MB verbose=False, ) + def create_router_args(self, **kwargs): + """Create router arguments by updating default args with provided kwargs.""" + args_dict = vars(self.default_args).copy() + args_dict.update(kwargs) + return SimpleNamespace(**args_dict) + + def run_router_process(self, args): + """Run router in a separate process and verify it starts successfully.""" + def run_router(): try: from sglang_router.launch_router import launch_router @@ -50,7 +60,6 @@ def run_router(): print(e) return 1 - # Start router in separate process process = multiprocessing.Process(target=run_router) try: process.start() @@ -61,6 +70,14 @@ def run_router(): finally: terminate_process(process) + def test_launch_router_common(self): + args = self.create_router_args(worker_urls=["http://localhost:8000"]) + self.run_router_process(args) + + def test_launch_router_with_empty_worker_urls(self): + args = self.create_router_args(worker_urls=[]) + self.run_router_process(args) + if __name__ == "__main__": unittest.main() diff --git a/sgl-router/py_test/test_launch_server.py b/sgl-router/py_test/test_launch_server.py new file mode 100644 index 
00000000000..80659fc4f3e --- /dev/null +++ b/sgl-router/py_test/test_launch_server.py @@ -0,0 +1,394 @@ +import socket +import subprocess +import time +import unittest +from types import SimpleNamespace + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, +) + + +def popen_launch_router( + model: str, + base_url: str, + dp_size: int, + timeout: float, + policy: str = "cache_aware", + max_payload_size: int = None, + api_key: str = None, +): + """ + Launch the router server process. + + Args: + model: Model path/name + base_url: Server base URL + dp_size: Data parallel size + timeout: Server launch timeout + policy: Router policy, one of "cache_aware", "round_robin", "random" + max_payload_size: Maximum payload size in bytes + api_key: API key for the router + """ + _, host, port = base_url.split(":") + host = host[2:] + + command = [ + "python3", + "-m", + "sglang_router.launch_server", + "--model-path", + model, + "--host", + host, + "--port", + port, + "--dp", + str(dp_size), + "--router-eviction-interval", + "5", + "--router-policy", + policy, + ] + + if api_key is not None: + command.extend(["--api-key", api_key]) + + if max_payload_size is not None: + command.extend(["--router-max-payload-size", str(max_payload_size)]) + + process = subprocess.Popen(command, stdout=None, stderr=None) + + start_time = time.time() + with requests.Session() as session: + while time.time() - start_time < timeout: + try: + response = session.get(f"{base_url}/health") + if response.status_code == 200: + print(f"Router {base_url} is healthy") + return process + except requests.RequestException: + pass + time.sleep(10) + + raise TimeoutError("Router failed to start within the timeout period.") + + +def find_available_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def popen_launch_server( + model: str, + base_url: str, + timeout: float, +): + _, host, port = base_url.split(":") + host = host[2:] + + command = [ + "python3", + "-m", + "sglang.launch_server", + "--model-path", + model, + "--host", + host, + "--port", + port, + "--base-gpu-id", + "1", + ] + + process = subprocess.Popen(command, stdout=None, stderr=None) + + # intentionally don't wait and defer the job to the router health check + return process + + +def terminate_and_wait(process, timeout=300): + """Terminate a process and wait until it is terminated. 
+ + Args: + process: subprocess.Popen object + timeout: maximum time to wait in seconds + + Raises: + TimeoutError: if process does not terminate within timeout + """ + if process is None: + return + + process.terminate() + start_time = time.time() + + while process.poll() is None: + print(f"Terminating process {process.pid}") + if time.time() - start_time > timeout: + raise TimeoutError( + f"Process {process.pid} failed to terminate within {timeout}s" + ) + time.sleep(1) + + print(f"Process {process.pid} is successfully terminated") + + +class TestLaunchServer(unittest.TestCase): + def setUp(self): + self.model = DEFAULT_MODEL_NAME_FOR_TEST + self.base_url = DEFAULT_URL_FOR_TEST + self.process = None + self.other_process = [] + + def tearDown(self): + print("Running tearDown...") + if self.process: + terminate_and_wait(self.process) + for process in self.other_process: + terminate_and_wait(process) + print("tearDown done") + + def test_1_mmlu(self): + print("Running test_1_mmlu...") + # DP size = 2 + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=2, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="cache_aware", + ) + + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + temperature=0.1, + ) + + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) + + def test_2_add_and_remove_worker(self): + print("Running test_2_add_and_remove_worker...") + # DP size = 1 + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", # use round robin to make sure every worker processes requests + ) + # 1. start a worker + port = find_available_port() + worker_url = f"http://127.0.0.1:{port}" + worker_process = popen_launch_server( + self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) + self.other_process.append(worker_process) + + # 2. use the /add_worker api to add it to the router. It will be used by the router after it is healthy + with requests.Session() as session: + response = session.post(f"{self.base_url}/add_worker?url={worker_url}") + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual(response.status_code, 200) + + # 3. run mmlu + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + temperature=0.1, + ) + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) + + # 4. use the /remove_worker api to remove it from the router + with requests.Session() as session: + response = session.post(f"{self.base_url}/remove_worker?url={worker_url}") + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual(response.status_code, 200) + + # 5. 
run mmlu again + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) + + def test_3_lazy_fault_tolerance(self): + print("Running test_3_lazy_fault_tolerance...") + # DP size = 1 + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", + ) + + # 1. start a worker + port = find_available_port() + worker_url = f"http://127.0.0.1:{port}" + worker_process = popen_launch_server( + self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + ) + self.other_process.append(worker_process) + + # 2. use the /add_worker api to add it to the router. It will be used by the router after it is healthy + with requests.Session() as session: + response = session.post(f"{self.base_url}/add_worker?url={worker_url}") + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual(response.status_code, 200) + + # Start a thread to kill the worker after 10 seconds to mimic abrupt worker failure + def kill_worker(): + time.sleep(10) + kill_process_tree(worker_process.pid) + print("Worker process killed") + + import threading + + kill_thread = threading.Thread(target=kill_worker) + kill_thread.daemon = True + kill_thread.start() + + # 3. run mmlu + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=256, + num_threads=32, + temperature=0.1, + ) + metrics = run_eval(args) + score = metrics["score"] + THRESHOLD = 0.65 + passed = score >= THRESHOLD + msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})" + self.assertGreaterEqual(score, THRESHOLD, msg) + + def test_4_payload_size(self): + print("Running test_4_payload_size...") + # Start router with a 1MB payload limit + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", + max_payload_size=1 * 1024 * 1024, # 1MB limit + ) + + # Test case 1: Payload just under 1MB should succeed + payload_0_5_mb = { + "text": "x" * int(0.5 * 1024 * 1024), # 0.5MB of text + "temperature": 0.0, + } + + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json=payload_0_5_mb, + headers={"Content-Type": "application/json"}, + ) + self.assertEqual( + response.status_code, + 200, + f"0.5MB payload should succeed but got status {response.status_code}", + ) + + # Test case 2: Payload over 1MB should fail + payload_1_plus_mb = { + "text": "x" * int((1.2 * 1024 * 1024)), # 1.2MB of text + "temperature": 0.0, + } + + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json=payload_1_plus_mb, + headers={"Content-Type": "application/json"}, + ) + self.assertEqual( + response.status_code, + 413, # Payload Too Large + f"1.2MB payload should fail with 413 but got status {response.status_code}", + ) + + def test_5_api_key(self): + print("Running test_5_api_key...") + + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", + api_key="correct_api_key", + ) + + # Test case 1: request without api key should fail + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json={"text": "Kanye 
west is, ", "temperature": 0}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, + 401, + "Request without api key should fail with 401", + ) + + # Test case 2: request with invalid api key should fail + with requests.Session() as session: + response = requests.post( + f"{self.base_url}/generate", + json={"text": "Kanye west is, ", "temperature": 0}, + headers={"Authorization": "Bearer 123"}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, + 401, + "Request with invalid api key should fail with 401", + ) + + # Test case 3: request with correct api key should succeed + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json={"text": "Kanye west is ", "temperature": 0}, + headers={"Authorization": "Bearer correct_api_key"}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, 200, "Request with correct api key should succeed" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/rust/pyproject.toml b/sgl-router/pyproject.toml similarity index 87% rename from rust/pyproject.toml rename to sgl-router/pyproject.toml index d1327d9203e..da5c44a1196 100644 --- a/rust/pyproject.toml +++ b/sgl-router/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang-router" -version = "0.0.11" +version = "0.1.4" description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances." authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}] requires-python = ">=3.8" @@ -20,6 +20,10 @@ classifiers = [ [tool.setuptools.packages] find = { where = ["py_src"] } +# workaround for https://github.com/pypa/twine/issues/1216 +[tool.setuptools] +license-files = [] + [[tool.setuptools-rust.ext-modules]] target = "sglang_router_rs" path = "Cargo.toml" diff --git a/rust/src/lib.rs b/sgl-router/src/lib.rs similarity index 70% rename from rust/src/lib.rs rename to sgl-router/src/lib.rs index 63d5bfe324a..ba9aeac1fef 100644 --- a/rust/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -17,11 +17,14 @@ struct Router { port: u16, worker_urls: Vec, policy: PolicyType, + worker_startup_timeout_secs: u64, + worker_startup_check_interval: u64, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, eviction_interval_secs: u64, max_tree_size: usize, + max_payload_size: usize, verbose: bool, } @@ -33,11 +36,14 @@ impl Router { policy = PolicyType::RoundRobin, host = String::from("127.0.0.1"), port = 3001, + worker_startup_timeout_secs = 300, + worker_startup_check_interval = 10, cache_threshold = 0.50, balance_abs_threshold = 32, balance_rel_threshold = 1.0001, eviction_interval_secs = 60, max_tree_size = 2usize.pow(24), + max_payload_size = 4 * 1024 * 1024, verbose = false ))] fn new( @@ -45,11 +51,14 @@ impl Router { policy: PolicyType, host: String, port: u16, + worker_startup_timeout_secs: u64, + worker_startup_check_interval: u64, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, eviction_interval_secs: u64, max_tree_size: usize, + max_payload_size: usize, verbose: bool, ) -> PyResult { Ok(Router { @@ -57,20 +66,31 @@ impl Router { port, worker_urls, policy, + worker_startup_timeout_secs, + worker_startup_check_interval, cache_threshold, balance_abs_threshold, balance_rel_threshold, eviction_interval_secs, 
max_tree_size, + max_payload_size, verbose, }) } fn start(&self) -> PyResult<()> { let policy_config = match &self.policy { - PolicyType::Random => router::PolicyConfig::RandomConfig, - PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig, + PolicyType::Random => router::PolicyConfig::RandomConfig { + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval, + }, + PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig { + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval, + }, PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig { + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval, cache_threshold: self.cache_threshold, balance_abs_threshold: self.balance_abs_threshold, balance_rel_threshold: self.balance_rel_threshold, @@ -86,12 +106,12 @@ impl Router { worker_urls: self.worker_urls.clone(), policy_config, verbose: self.verbose, + max_payload_size: self.max_payload_size, }) .await - .unwrap(); - }); - - Ok(()) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + Ok(()) + }) } } diff --git a/sgl-router/src/router.rs b/sgl-router/src/router.rs new file mode 100644 index 00000000000..5ee34c59869 --- /dev/null +++ b/sgl-router/src/router.rs @@ -0,0 +1,809 @@ +use crate::tree::Tree; +use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; +use actix_web::{HttpRequest, HttpResponse}; +use bytes::Bytes; +use futures_util::{StreamExt, TryStreamExt}; +use log::{debug, error, info, warn}; +use std::collections::HashMap; +use std::fmt::Debug; +use std::sync::atomic::AtomicUsize; +use std::sync::{Arc, Mutex, RwLock}; +use std::thread; +use std::time::Duration; +use tokio; + +fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> { + req.headers() + .iter() + .filter_map(|(name, value)| { + value + .to_str() + .ok() + .map(|v| (name.to_string(), v.to_string())) + }) + .collect() +} + +#[derive(Debug)] +pub enum Router { + RoundRobin { + worker_urls: Arc>>, + current_index: AtomicUsize, + timeout_secs: u64, + interval_secs: u64, + }, + Random { + worker_urls: Arc>>, + timeout_secs: u64, + interval_secs: u64, + }, + CacheAware { + /* + Cache-Aware Load Balancing Router + + This router combines two strategies to optimize both cache utilization and request distribution: + + 1. Cache-Aware Routing (Approximate Tree) + 2. Load Balancing (Shortest Queue with Balance Thresholds) + + The router dynamically switches between these strategies based on load conditions: + - Uses load balancing when the system is imbalanced + - Uses cache-aware routing when the system is balanced + + A system is considered imbalanced if both conditions are met: + 1. (max - min) > abs_threshold + 2. max > rel_threshold * min + + Strategy Details: + + 1. Cache-Aware Routing (Approximate Tree) + ------------------------------------------- + This strategy maintains an approximate radix tree for each worker based on request history, + eliminating the need for direct cache state queries. The tree stores raw text characters + instead of token IDs to avoid tokenization overhead. + + Process: + a. For each request, find the worker with the highest prefix match + b. If match rate > cache_threshold: + Route to the worker with highest match (likely has relevant data cached) + c. If match rate ≤ cache_threshold: + Route to the worker with smallest tree size (most available cache capacity) + d. 
Background maintenance: + Periodically evict least recently used leaf nodes to prevent memory overflow + + 2. Load Balancing (Shortest Queue) + ------------------------------------------- + This strategy tracks pending request counts per worker and routes new requests + to the least busy worker when the system is detected to be imbalanced. + + Configuration Parameters: + ------------------------ + 1. cache_threshold: (float, 0.0 to 1.0) + Minimum prefix match ratio to use highest-match routing. + Below this threshold, routes to worker with most available cache space. + + 2. balance_abs_threshold: (integer) + Absolute difference threshold for load imbalance detection. + System is potentially imbalanced if (max_load - min_load) > abs_threshold + + 3. balance_rel_threshold: (float) + Relative ratio threshold for load imbalance detection. + System is potentially imbalanced if max_load > min_load * rel_threshold + Used in conjunction with abs_threshold to determine final imbalance state. + + 4. eviction_interval_secs: (integer) + Interval between LRU eviction cycles for the approximate trees. + + 5. max_tree_size: (integer) + Maximum nodes per tree. When exceeded, LRU leaf nodes are evicted + during the next eviction cycle. + */ + worker_urls: Arc>>, + tree: Arc>, + running_queue: Arc>>, + processed_queue: Arc>>, + cache_threshold: f32, + balance_abs_threshold: usize, + balance_rel_threshold: f32, + timeout_secs: u64, + interval_secs: u64, + _eviction_thread: Option>, + }, +} + +#[derive(Debug, Clone)] +pub enum PolicyConfig { + RandomConfig { + timeout_secs: u64, + interval_secs: u64, + }, + RoundRobinConfig { + timeout_secs: u64, + interval_secs: u64, + }, + CacheAwareConfig { + cache_threshold: f32, + balance_abs_threshold: usize, + balance_rel_threshold: f32, + eviction_interval_secs: u64, + max_tree_size: usize, + timeout_secs: u64, + interval_secs: u64, + }, +} + +impl Router { + pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Result { + // Get timeout and interval from policy config + let (timeout_secs, interval_secs) = match &policy_config { + PolicyConfig::RandomConfig { + timeout_secs, + interval_secs, + } => (*timeout_secs, *interval_secs), + PolicyConfig::RoundRobinConfig { + timeout_secs, + interval_secs, + } => (*timeout_secs, *interval_secs), + PolicyConfig::CacheAwareConfig { + timeout_secs, + interval_secs, + .. + } => (*timeout_secs, *interval_secs), + }; + + // Wait until all workers are healthy + Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?; + + // Create router based on policy... 
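+        // Each PolicyConfig variant carries the worker startup timeout and the
+        // health-check interval; CacheAwareConfig additionally supplies the cache
+        // threshold, load-balance thresholds, eviction interval, and max tree size.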
+ Ok(match policy_config { + PolicyConfig::RandomConfig { + timeout_secs, + interval_secs, + } => Router::Random { + worker_urls: Arc::new(RwLock::new(worker_urls)), + timeout_secs, + interval_secs, + }, + PolicyConfig::RoundRobinConfig { + timeout_secs, + interval_secs, + } => Router::RoundRobin { + worker_urls: Arc::new(RwLock::new(worker_urls)), + current_index: std::sync::atomic::AtomicUsize::new(0), + timeout_secs, + interval_secs, + }, + PolicyConfig::CacheAwareConfig { + cache_threshold, + balance_abs_threshold, + balance_rel_threshold, + eviction_interval_secs, + max_tree_size, + timeout_secs, + interval_secs, + } => { + let mut running_queue = HashMap::new(); + for url in &worker_urls { + running_queue.insert(url.clone(), 0); + } + + let mut processed_queue = HashMap::new(); + for url in &worker_urls { + processed_queue.insert(url.clone(), 0); + } + + let tree = Arc::new(Mutex::new(Tree::new())); + let running_queue = Arc::new(Mutex::new(running_queue)); + let processed_queue = Arc::new(Mutex::new(processed_queue)); + + // Create background eviction thread + let tree_clone = Arc::clone(&tree); + let processed_queue_clone = Arc::clone(&processed_queue); + let running_queue_clone = Arc::clone(&running_queue); + let eviction_thread = thread::spawn(move || { + loop { + // Sleep for the specified interval + thread::sleep(Duration::from_secs(eviction_interval_secs)); + + let locked_tree_clone = tree_clone.lock().unwrap(); + // Run eviction + locked_tree_clone.evict_tenant_by_size(max_tree_size); + + // Print the process queue + let locked_processed_queue = processed_queue_clone.lock().unwrap(); + info!("Processed Queue: {:?}", locked_processed_queue); + + // Print the running queue + let locked_running_queue = running_queue_clone.lock().unwrap(); + info!("Running Queue: {:?}", locked_running_queue); + } + }); + + for url in &worker_urls { + tree.lock().unwrap().insert(&"".to_string(), url); + } + + Router::CacheAware { + worker_urls: Arc::new(RwLock::new(worker_urls)), + tree, + running_queue, + processed_queue, + cache_threshold, + balance_abs_threshold, + balance_rel_threshold, + timeout_secs, + interval_secs, + _eviction_thread: Some(eviction_thread), + } + } + }) + } + + fn wait_for_healthy_workers( + worker_urls: &[String], + timeout_secs: u64, + interval_secs: u64, + ) -> Result<(), String> { + let start_time = std::time::Instant::now(); + let sync_client = reqwest::blocking::Client::new(); + + loop { + if start_time.elapsed() > Duration::from_secs(timeout_secs) { + error!( + "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_urls + ); + return Err(format!( + "Timeout {}s waiting for workers {:?} to become healthy. 
Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_urls + )); + } + + let mut all_healthy = true; + let mut unhealthy_workers = Vec::new(); + + for url in worker_urls { + match sync_client.get(&format!("{}/health", url)).send() { + Ok(res) => { + if !res.status().is_success() { + info!( + "Worker {} health check is pending with status: {}.", + url, + res.status() + ); + all_healthy = false; + unhealthy_workers.push((url, format!("Status: {}", res.status()))); + } + } + Err(e) => { + info!("Worker {} health check is pending with error: {}", url, e); + all_healthy = false; + unhealthy_workers.push((url, format!("Error: {}", e))); + } + } + } + + if all_healthy { + info!("All workers are healthy"); + return Ok(()); + } else { + info!("Unhealthy workers:"); + for (url, reason) in &unhealthy_workers { + info!(" {} - {}", url, reason); + } + thread::sleep(Duration::from_secs(interval_secs)); + } + } + } + + fn select_first_worker(&self) -> Result { + match self { + Router::RoundRobin { worker_urls, .. } + | Router::Random { worker_urls, .. } + | Router::CacheAware { worker_urls, .. } => { + if worker_urls.read().unwrap().is_empty() { + Err("No workers are available".to_string()) + } else { + Ok(worker_urls.read().unwrap()[0].clone()) + } + } + } + } + + async fn send_request( + &self, + client: &reqwest::Client, + worker_url: &str, + route: &str, + req: &HttpRequest, + ) -> HttpResponse { + let mut request_builder = client.get(format!("{}{}", worker_url, route)); + + // Copy all headers from original request except for /health because it does not need authorization + if route != "/health" { + for (name, value) in copy_request_headers(req) { + request_builder = request_builder.header(name, value); + } + } + + match request_builder.send().await { + Ok(res) => { + let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + + match res.bytes().await { + Ok(body) => HttpResponse::build(status).body(body.to_vec()), + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to read response body: {}", e)), + } + } + Err(e) => HttpResponse::InternalServerError().body(format!( + "Failed to send request to worker {}: {}", + worker_url, e + )), + } + } + + pub async fn route_to_first( + &self, + client: &reqwest::Client, + route: &str, + req: &HttpRequest, + ) -> HttpResponse { + const MAX_REQUEST_RETRIES: u32 = 3; + const MAX_TOTAL_RETRIES: u32 = 6; + let mut total_retries = 0; + + while total_retries < MAX_TOTAL_RETRIES { + match self.select_first_worker() { + Ok(worker_url) => { + let mut request_retries = 0; + + // Try the same worker multiple times + while request_retries < MAX_REQUEST_RETRIES { + if total_retries >= 1 { + info!("Retrying request after {} failed attempts", total_retries); + } + + let response = self.send_request(client, &worker_url, route, req).await; + + if response.status().is_success() { + return response; + } else { + // if the worker is healthy, it means the request is bad, so return the error response + let health_response = + self.send_request(client, &worker_url, "/health", req).await; + if health_response.status().is_success() { + return response; + } + } + + warn!( + "Request to {} failed (attempt {}/{})", + worker_url, + request_retries + 1, + MAX_REQUEST_RETRIES + ); + + request_retries += 1; + total_retries += 1; + + if request_retries == 
MAX_REQUEST_RETRIES { + warn!("Removing failed worker: {}", worker_url); + self.remove_worker(&worker_url); + break; + } + } + } + Err(e) => return HttpResponse::InternalServerError().body(e), + } + } + + HttpResponse::InternalServerError().body("All retry attempts failed") + } + + fn get_text_from_request(&self, body: &Bytes, route: &str) -> String { + // convert body to json + let json = serde_json::from_slice::(body).unwrap(); + + if route == "generate" { + // get the "text" field + let text = json.get("text").and_then(|t| t.as_str()).unwrap_or(""); + return text.to_string(); + } else if route == "v1/chat/completions" { + // get the messages field as raw text + if let Some(messages) = json.get("messages") { + // Convert messages back to a string, preserving all JSON formatting + return serde_json::to_string(messages).unwrap_or_default(); + } + } else if route == "v1/completions" { + let prompt = json.get("prompt").and_then(|t| t.as_str()).unwrap_or(""); + return prompt.to_string(); + } + + return "".to_string(); + } + + // TODO: return Result instead of panicking + fn select_generate_worker(&self, body: &Bytes, route: &str) -> String { + let text = self.get_text_from_request(&body, route); + + let worker_url = match self { + Router::RoundRobin { + worker_urls, + current_index, + .. + } => { + let idx = current_index + .fetch_update( + std::sync::atomic::Ordering::SeqCst, + std::sync::atomic::Ordering::SeqCst, + |x| Some((x + 1) % worker_urls.read().unwrap().len()), + ) + .unwrap(); + worker_urls.read().unwrap()[idx].clone() + } + + Router::Random { worker_urls, .. } => worker_urls.read().unwrap() + [rand::random::() % worker_urls.read().unwrap().len()] + .clone(), + + Router::CacheAware { + worker_urls, + tree, + running_queue, + processed_queue, + cache_threshold, + balance_abs_threshold, + balance_rel_threshold, + .. + } => { + // TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones + + let tree = tree.lock().unwrap(); + let mut running_queue = running_queue.lock().unwrap(); + + // Get current load statistics + let max_load = *running_queue.values().max().unwrap_or(&0); + let min_load = *running_queue.values().min().unwrap_or(&0); + + // Load is considered imbalanced if: + // 1. (max - min) > abs_threshold AND + // 2. 
max > rel_threshold * min + let is_imbalanced = max_load.saturating_sub(min_load) > *balance_abs_threshold + && (max_load as f32) > (min_load as f32 * balance_rel_threshold); + + let selected_url = if is_imbalanced { + // Log load balancing trigger and current queue state + info!( + "Load balancing triggered due to workload imbalance:\n\ + Max load: {}, Min load: {}\n\ + Current running queue: {:?}", + max_load, min_load, running_queue + ); + + // Use shortest queue routing when load is imbalanced + running_queue + .iter() + .min_by_key(|(_url, &count)| count) + .map(|(url, _)| url.clone()) + .unwrap_or_else(|| worker_urls.read().unwrap()[0].clone()) + } else { + // Use cache-aware routing when load is balanced + let (matched_text, matched_worker) = tree.prefix_match(&text); + let matched_rate = + matched_text.chars().count() as f32 / text.chars().count() as f32; + + if matched_rate > *cache_threshold { + matched_worker.to_string() + } else { + tree.get_smallest_tenant() + } + }; + + // Update queues and tree + *running_queue.get_mut(&selected_url).unwrap() += 1; + + *processed_queue + .lock() + .unwrap() + .get_mut(&selected_url) + .unwrap() += 1; + tree.insert(&text, &selected_url); + + selected_url + } + }; + + worker_url + } + + async fn send_generate_request( + &self, + client: &reqwest::Client, + req: &HttpRequest, + body: &Bytes, + route: &str, + worker_url: &str, + ) -> HttpResponse { + let is_stream = serde_json::from_slice::(&body) + .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false)) + .unwrap_or(false); + + let mut request_builder = client + .post(format!("{}{}", worker_url, route)) + .body(body.to_vec()); + + // Copy all headers from original request + for (name, value) in copy_request_headers(req) { + request_builder = request_builder.header(name, value); + } + + let res = match request_builder.send().await { + Ok(res) => res, + Err(_) => return HttpResponse::InternalServerError().finish(), + }; + + let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + + if !is_stream { + // For non-streaming requests, get response first + let response = match res.bytes().await { + Ok(body) => HttpResponse::build(status).body(body.to_vec()), + Err(e) => { + let error_msg = format!("Failed to get response body: {}", e); + HttpResponse::InternalServerError().body(error_msg) + } + }; + + // Then decrement running queue counter if using CacheAware + if let Router::CacheAware { running_queue, .. } = self { + if let Ok(mut queue) = running_queue.lock() { + if let Some(count) = queue.get_mut(worker_url) { + *count = count.saturating_sub(1); + } + } + } + + response + } else if let Router::CacheAware { running_queue, .. 
} = self { + let running_queue = Arc::clone(running_queue); + let worker_url = worker_url.to_string(); + + HttpResponse::build(status) + .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) + .streaming( + res.bytes_stream() + .map_err(|_| { + actix_web::error::ErrorInternalServerError("Failed to read stream") + }) + .inspect(move |bytes| { + let bytes = bytes.as_ref().unwrap(); + if bytes + .as_ref() + .windows(12) + .any(|window| window == b"data: [DONE]") + { + let mut locked_queue = running_queue.lock().unwrap(); + let count = locked_queue.get_mut(&worker_url).unwrap(); + *count = count.saturating_sub(1); + debug!("Streaming is done!!") + } + }), + ) + } else { + HttpResponse::build(status) + .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) + .streaming(res.bytes_stream().map_err(|_| { + actix_web::error::ErrorInternalServerError("Failed to read stream") + })) + } + } + + pub async fn route_generate_request( + &self, + client: &reqwest::Client, + req: &HttpRequest, + body: &Bytes, + route: &str, + ) -> HttpResponse { + const MAX_REQUEST_RETRIES: u32 = 3; + const MAX_TOTAL_RETRIES: u32 = 6; + let mut total_retries = 0; + + while total_retries < MAX_TOTAL_RETRIES { + let worker_url = self.select_generate_worker(body, route); + let mut request_retries = 0; + + // Try the same worker multiple times + while request_retries < MAX_REQUEST_RETRIES { + if total_retries >= 1 { + info!("Retrying request after {} failed attempts", total_retries); + } + let response = self + .send_generate_request(client, req, body, route, &worker_url) + .await; + + if response.status().is_success() { + return response; + } else { + // if the worker is healthy, it means the request is bad, so return the error response + let health_response = + self.send_request(client, &worker_url, "/health", req).await; + if health_response.status().is_success() { + return response; + } + } + + warn!( + "Generate request to {} failed (attempt {}/{})", + worker_url, + request_retries + 1, + MAX_REQUEST_RETRIES + ); + + request_retries += 1; + total_retries += 1; + + if request_retries == MAX_REQUEST_RETRIES { + warn!("Removing failed worker: {}", worker_url); + self.remove_worker(&worker_url); + break; + } + } + } + + HttpResponse::InternalServerError().body("All retry attempts failed") + } + + pub async fn add_worker(&self, worker_url: &str) -> Result { + let (timeout_secs, interval_secs) = match self { + Router::Random { + timeout_secs, + interval_secs, + .. + } => (*timeout_secs, *interval_secs), + Router::RoundRobin { + timeout_secs, + interval_secs, + .. + } => (*timeout_secs, *interval_secs), + Router::CacheAware { + timeout_secs, + interval_secs, + .. + } => (*timeout_secs, *interval_secs), + }; + + let start_time = std::time::Instant::now(); + let client = reqwest::Client::new(); + + loop { + if start_time.elapsed() > Duration::from_secs(timeout_secs) { + error!( + "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_url + ); + return Err(format!( + "Timeout {}s waiting for worker {} to become healthy. 
Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_url + )); + } + + match client.get(&format!("{}/health", worker_url)).send().await { + Ok(res) => { + if res.status().is_success() { + match self { + Router::RoundRobin { worker_urls, .. } + | Router::Random { worker_urls, .. } + | Router::CacheAware { worker_urls, .. } => { + info!("Worker {} health check passed", worker_url); + let mut urls = worker_urls.write().unwrap(); + if urls.contains(&worker_url.to_string()) { + return Err(format!("Worker {} already exists", worker_url)); + } + info!("Added worker: {}", worker_url); + urls.push(worker_url.to_string()); + } + } + + // If cache aware, initialize the queues for the new worker + if let Router::CacheAware { + running_queue, + processed_queue, + tree, + .. + } = self + { + // Add worker to running queue with initial count of 0 + running_queue + .lock() + .unwrap() + .insert(worker_url.to_string(), 0); + + // Add worker to processed queue with initial count of 0 + processed_queue + .lock() + .unwrap() + .insert(worker_url.to_string(), 0); + + // Add worker to tree + tree.lock().unwrap().insert(&"".to_string(), &worker_url); + } + + return Ok(format!("Successfully added worker: {}", worker_url)); + } else { + info!( + "Worker {} health check is pending with status: {}.", + worker_url, + res.status() + ); + // if the url does not have http or https prefix, warn users + if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") + { + warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url); + } + + tokio::time::sleep(Duration::from_secs(interval_secs)).await; + continue; + } + } + Err(e) => { + info!( + "Worker {} health check is pending with error: {}", + worker_url, e + ); + + // if the url does not have http or https prefix, warn users + if !worker_url.starts_with("http://") && !worker_url.starts_with("https://") { + warn!("The worker url {} does not have http or https prefix. Please add the prefix to the url.", worker_url); + } + + tokio::time::sleep(Duration::from_secs(interval_secs)).await; + continue; + } + } + } + } + + pub fn remove_worker(&self, worker_url: &str) { + match self { + Router::RoundRobin { worker_urls, .. } + | Router::Random { worker_urls, .. } + | Router::CacheAware { worker_urls, .. } => { + let mut urls = worker_urls.write().unwrap(); + if let Some(index) = urls.iter().position(|url| url == &worker_url) { + urls.remove(index); + info!("Removed worker: {}", worker_url); + } else { + warn!("Worker {} not found, skipping removal", worker_url); + return; + } + } + } + + // if cache aware, remove the worker from the tree + if let Router::CacheAware { + tree, + running_queue, + processed_queue, + .. 
+ } = self + { + tree.lock().unwrap().remove_tenant(&worker_url); + running_queue + .lock() + .unwrap() + .remove(&worker_url.to_string()); + processed_queue + .lock() + .unwrap() + .remove(&worker_url.to_string()); + info!( + "Removed worker from tree and cleaned up queues: {}", + worker_url + ); + } + } +} diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs new file mode 100644 index 00000000000..0706c57c06c --- /dev/null +++ b/sgl-router/src/server.rs @@ -0,0 +1,199 @@ +use crate::router::PolicyConfig; +use crate::router::Router; +use actix_web::{get, post, web, App, HttpRequest, HttpResponse, HttpServer, Responder}; +use bytes::Bytes; +use env_logger::Builder; +use log::{info, LevelFilter}; +use std::collections::HashMap; +use std::io::Write; + +#[derive(Debug)] +pub struct AppState { + router: Router, + client: reqwest::Client, +} + +impl AppState { + pub fn new( + worker_urls: Vec, + client: reqwest::Client, + policy_config: PolicyConfig, + ) -> Result { + // Create router based on policy + let router = Router::new(worker_urls, policy_config)?; + Ok(Self { router, client }) + } +} + +#[get("/health")] +async fn health(req: HttpRequest, data: web::Data) -> impl Responder { + data.router + .route_to_first(&data.client, "/health", &req) + .await +} + +#[get("/health_generate")] +async fn health_generate(req: HttpRequest, data: web::Data) -> impl Responder { + data.router + .route_to_first(&data.client, "/health_generate", &req) + .await +} + +#[get("/get_server_info")] +async fn get_server_info(req: HttpRequest, data: web::Data) -> impl Responder { + data.router + .route_to_first(&data.client, "/get_server_info", &req) + .await +} + +#[get("/v1/models")] +async fn v1_models(req: HttpRequest, data: web::Data) -> impl Responder { + data.router + .route_to_first(&data.client, "/v1/models", &req) + .await +} + +#[get("/get_model_info")] +async fn get_model_info(req: HttpRequest, data: web::Data) -> impl Responder { + data.router + .route_to_first(&data.client, "/get_model_info", &req) + .await +} + +#[post("/generate")] +async fn generate(req: HttpRequest, body: Bytes, data: web::Data) -> impl Responder { + data.router + .route_generate_request(&data.client, &req, &body, "/generate") + .await +} + +#[post("/v1/chat/completions")] +async fn v1_chat_completions( + req: HttpRequest, + body: Bytes, + data: web::Data, +) -> impl Responder { + data.router + .route_generate_request(&data.client, &req, &body, "/v1/chat/completions") + .await +} + +#[post("/v1/completions")] +async fn v1_completions( + req: HttpRequest, + body: Bytes, + data: web::Data, +) -> impl Responder { + data.router + .route_generate_request(&data.client, &req, &body, "/v1/completions") + .await +} + +#[post("/add_worker")] +async fn add_worker( + query: web::Query>, + data: web::Data, +) -> impl Responder { + let worker_url = match query.get("url") { + Some(url) => url.to_string(), + None => { + return HttpResponse::BadRequest() + .body("Worker URL required. 
Provide 'url' query parameter") + } + }; + + match data.router.add_worker(&worker_url).await { + Ok(message) => HttpResponse::Ok().body(message), + Err(error) => HttpResponse::BadRequest().body(error), + } +} + +#[post("/remove_worker")] +async fn remove_worker( + query: web::Query>, + data: web::Data, +) -> impl Responder { + let worker_url = match query.get("url") { + Some(url) => url.to_string(), + None => return HttpResponse::BadRequest().finish(), + }; + data.router.remove_worker(&worker_url); + HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url)) +} + +pub struct ServerConfig { + pub host: String, + pub port: u16, + pub worker_urls: Vec, + pub policy_config: PolicyConfig, + pub verbose: bool, + pub max_payload_size: usize, +} + +pub async fn startup(config: ServerConfig) -> std::io::Result<()> { + // Initialize logger + Builder::new() + .format(|buf, record| { + use chrono::Local; + writeln!( + buf, + "[Router (Rust)] {} - {} - {}", + Local::now().format("%Y-%m-%d %H:%M:%S"), + record.level(), + record.args() + ) + }) + .filter( + None, + if config.verbose { + LevelFilter::Debug + } else { + LevelFilter::Info + }, + ) + .init(); + + info!("🚧 Initializing router on {}:{}", config.host, config.port); + info!("🚧 Initializing workers on {:?}", config.worker_urls); + info!("🚧 Policy Config: {:?}", config.policy_config); + info!( + "🚧 Max payload size: {} MB", + config.max_payload_size / (1024 * 1024) + ); + + let client = reqwest::Client::builder() + .build() + .expect("Failed to create HTTP client"); + + let app_state = web::Data::new( + AppState::new( + config.worker_urls.clone(), + client, + config.policy_config.clone(), + ) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?, + ); + + info!("✅ Serving router on {}:{}", config.host, config.port); + info!("✅ Serving workers on {:?}", config.worker_urls); + + HttpServer::new(move || { + App::new() + .app_data(app_state.clone()) + .app_data(web::JsonConfig::default().limit(config.max_payload_size)) + .app_data(web::PayloadConfig::default().limit(config.max_payload_size)) + .service(generate) + .service(v1_chat_completions) + .service(v1_completions) + .service(v1_models) + .service(get_model_info) + .service(health) + .service(health_generate) + .service(get_server_info) + .service(add_worker) + .service(remove_worker) + }) + .bind((config.host, config.port))? + .run() + .await +} diff --git a/rust/src/tree.rs b/sgl-router/src/tree.rs similarity index 100% rename from rust/src/tree.rs rename to sgl-router/src/tree.rs diff --git a/sgl-router/v0.1.0.md b/sgl-router/v0.1.0.md new file mode 100644 index 00000000000..747731a71c2 --- /dev/null +++ b/sgl-router/v0.1.0.md @@ -0,0 +1,63 @@ +# SGLang Router v0.1.0: Dynamic Scaling and Fault Tolerance + +We have released `sglang-router` v0.1.0 equipped with dynamic scaling and fault tolerance! It is essential for the router to be able to dynamically scale the number of workers and handle worker failures. To achieve this, we have implemented the following features: + +## 1. Dynamic scaling: The router can dynamically scale the number of workers based on the request load. + +We offer `/add_worker` and `/remove_worker` APIs to dynamically add or remove workers from the router. 
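+
+The per-endpoint `curl` usage is listed below. The same endpoints can also be called from a script; the following is a minimal sketch using Python `requests`, with placeholder router and worker URLs:
+
+```python
+import requests
+
+ROUTER_URL = "http://localhost:30000"  # assumed router address
+WORKER_URL = "http://127.0.0.1:30001"  # assumed address of a running sglang server
+
+# Register the worker with the router; it is used once its /health check passes.
+resp = requests.post(f"{ROUTER_URL}/add_worker", params={"url": WORKER_URL})
+print(resp.status_code, resp.text)
+
+# Remove the same worker later, e.g. when scaling down.
+resp = requests.post(f"{ROUTER_URL}/remove_worker", params={"url": WORKER_URL})
+print(resp.status_code, resp.text)
+```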
+
+- `/add_worker`
+
+Usage:
+
+```bash
+$ curl -X POST http://localhost:30000/add_worker?url=http://worker_url_1
+```
+
+Example:
+
+```bash
+$ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30001
+$ curl -X POST http://localhost:30000/add_worker?url=http://127.0.0.1:30001
+Successfully added worker: http://127.0.0.1:30001
+```
+
+- `/remove_worker`
+
+Usage:
+
+```bash
+$ curl -X POST http://localhost:30000/remove_worker?url=http://worker_url_1
+```
+
+Example:
+
+```bash
+$ curl -X POST http://localhost:30000/remove_worker?url=http://127.0.0.1:30001
+Successfully removed worker: http://127.0.0.1:30001
+```
+
+Note:
+
+- For the cache-aware router, the worker will also be removed from the tree and the queues.
+
+## 2. Fault tolerance: The router can handle worker failures and automatically remove the failed worker from the router.
+
+We provide retry-based fault tolerance:
+
+1. If a request to a worker fails `max_worker_retries` times, the router will remove that worker and move on to the next one.
+2. If the total number of retries exceeds `max_total_retries`, the router will return an error.
+
+Note:
+
+- `max_worker_retries` is 3 and `max_total_retries` is 6 by default.
+
+## Closing remarks:
+
+1. Please read the full usage at https://docs.sglang.ai/router/router.html
+2. The feature is still under active improvement, so please don't hesitate to raise issues or submit PRs if you have any suggestions or feedback.
+
+
+# Release Instructions
+
+Update the version in `sgl-router/pyproject.toml` and `py_src/sglang_router/version.py`.
diff --git a/test/README.md b/test/README.md
index 9825faf63bf..868061bbc4a 100644
--- a/test/README.md
+++ b/test/README.md
@@ -13,7 +13,7 @@ python3 test_srt_endpoint.py
 python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
 
 # Run a suite with multiple files
-python3 run_suite.py --suite minimal
+python3 run_suite.py --suite per-commit
 ```
 
 ## Test Frontend Language
@@ -25,8 +25,23 @@ export OPENAI_API_KEY=sk-*****
 python3 test_openai_backend.py
 
 # Run a single test
-python3 -m unittest test_openai_backend.TestOpenAIBackend.test_few_shot_qa
+python3 -m unittest test_openai_backend.TestOpenAIServer.test_few_shot_qa
 
 # Run a suite with multiple files
-python3 run_suite.py --suite minimal
+python3 run_suite.py --suite per-commit
 ```
+
+## Adding or Updating Tests in CI
+
+- Create new test files under `test/srt` or `test/lang` depending on the type of test.
+- Ensure they are referenced in the respective `run_suite.py` (e.g., `test/srt/run_suite.py` or `test/lang/run_suite.py`) so they’re picked up in CI. For most small test cases, they can be added to the `per-commit` suite.
+- The CI will run the `per-commit` and `nightly` suites automatically. If you need special setup or custom test groups, you may modify the workflows in [`.github/workflows/`](https://github.com/sgl-project/sglang/tree/main/.github/workflows).
+
+
+## Writing Elegant Test Cases
+
+- Examine existing tests in [sglang/test](https://github.com/sgl-project/sglang/tree/main/test) for practical examples.
+- Keep each test function focused on a single scenario or piece of functionality.
+- Give tests descriptive names reflecting their purpose.
+- Use robust assertions (e.g., assert, unittest methods) to validate outcomes.
+- Clean up resources to avoid side effects and preserve test independence.
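+
+A minimal skeleton that follows these guidelines, using the shared helpers from `sglang.test.test_utils` (the class and test names here are illustrative only):
+
+```python
+import unittest
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.test_utils import (
+    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestMyFeature(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        # One server for the whole class keeps the test fast and self-contained.
+        cls.process = popen_launch_server(
+            DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+            DEFAULT_URL_FOR_TEST,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        # Always clean up so later tests start from a fresh state.
+        kill_process_tree(cls.process.pid)
+
+    def test_generate_smoke(self):
+        # One focused scenario per test, with a robust assertion on the outcome.
+        response = requests.post(
+            f"{DEFAULT_URL_FOR_TEST}/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {"max_new_tokens": 8},
+            },
+        )
+        self.assertEqual(response.status_code, 200)
+
+
+if __name__ == "__main__":
+    unittest.main()
+```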
diff --git a/test/lang/run_suite.py b/test/lang/run_suite.py index 379427afac9..327d18b3fbd 100644 --- a/test/lang/run_suite.py +++ b/test/lang/run_suite.py @@ -4,7 +4,11 @@ from sglang.test.test_utils import run_unittest_files suites = { - "minimal": ["test_srt_backend.py", "test_openai_backend.py"], + "per-commit": [ + "test_srt_backend.py", + # Skip this due to some OPENAI_API_KEY issues + # "test_openai_backend.py", + ], } diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index b99606fc1cb..a4b1b88a23d 100644 --- a/test/lang/test_srt_backend.py +++ b/test/lang/test_srt_backend.py @@ -1,6 +1,7 @@ """ Usage: python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens +python3 -m unittest test_srt_backend.TestSRTBackend.test_hellaswag_select """ import unittest @@ -73,7 +74,7 @@ def test_hellaswag_select(self): # Run twice to capture more bugs for _ in range(2): accuracy, latency = test_hellaswag_select() - self.assertGreater(accuracy, 0.71) + self.assertGreater(accuracy, 0.70) def test_gen_min_new_tokens(self): test_gen_min_new_tokens() diff --git a/test/srt/configs/random_config.yaml b/test/srt/configs/random_config.yaml new file mode 100644 index 00000000000..eae8c27f41c --- /dev/null +++ b/test/srt/configs/random_config.yaml @@ -0,0 +1,25 @@ +tasks: + - name: sglang-128-4 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440 + - name: vllm-128-4 + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440 + - name: sglang-2000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120 + - name: vllm-2000-100 + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120 + - name: sglang-4000-200 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480 + - name: vllm-4000-200 + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct --disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480 + - name: sglang-32000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60 + - name: vllm-32000-100 + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-3.1-8B-Instruct 
--disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60 diff --git a/test/srt/configs/random_flashinfer_vs_triton_config.yaml b/test/srt/configs/random_flashinfer_vs_triton_config.yaml new file mode 100644 index 00000000000..7f4a386ddcf --- /dev/null +++ b/test/srt/configs/random_flashinfer_vs_triton_config.yaml @@ -0,0 +1,25 @@ +tasks: + - name: sglang-128-4 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440 + - name: sglang-triton-128-4 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 128 --random-output 4 --request-rate 24 --num-prompt 1440 + - name: sglang-2000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120 + - name: sglang-triton-2000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 2000 --random-output 100 --request-rate 2 --num-prompt 120 + - name: sglang-4000-200 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480 + - name: sglang-triton-4000-200 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 4000 --random-output 200 --request-rate 8 --num-prompt 480 + - name: sglang-32000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60 + - name: sglang-triton-32000-100 + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache --attention-backend triton + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60 diff --git a/test/srt/configs/sharegpt_config.yaml b/test/srt/configs/sharegpt_config.yaml new file mode 100644 index 00000000000..a80b96c8eae --- /dev/null +++ b/test/srt/configs/sharegpt_config.yaml @@ -0,0 +1,7 @@ +tasks: + - name: sglang-benchmark + server_cmd: python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --disable-radix-cache + client_cmd: python3 -m sglang.bench_serving --backend sglang --dataset-name sharegpt --request-rate 16 + - name: vllm-benchmark + server_cmd: python3 -m vllm.entrypoints.openai.api_server --model 
meta-llama/Llama-3.1-8B-Instruct --disable-log-requests + client_cmd: python3 -m sglang.bench_serving --backend vllm --dataset-name sharegpt --request-rate 16 diff --git a/test/srt/experiment_runner.py b/test/srt/experiment_runner.py new file mode 100644 index 00000000000..c4966dc77ba --- /dev/null +++ b/test/srt/experiment_runner.py @@ -0,0 +1,359 @@ +import argparse +import logging +import os +import queue +import re +import subprocess +import threading +import time +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import psutil +import requests +import yaml + + +@dataclass +class ServerConfig: + command: str + process_names: List[str] + default_port: int + + +@dataclass +class TaskConfig: + server_cmd: str + client_cmd: str + name: Optional[str] = None + server_type: Optional[str] = None + + +@dataclass +class TaskResult: + name: str + success: bool + output: str + runtime: float + timestamp: str + + +SERVER_DEFAULTS = { + "sglang": ServerConfig( + command="sglang.launch_server", + process_names=["sglang.launch_server"], + default_port=30000, + ), + "vllm": ServerConfig( + command="vllm.entrypoints.openai.api_server", + process_names=["vllm.entrypoints.openai.api_server"], + default_port=8000, + ), +} + + +def parse_key_info(output: str) -> str: + """Extract and format key information from the output""" + key_info = [] + + # Extract Args namespace + args_match = re.search(r"Namespace\(.*?\)", output, re.DOTALL) + if args_match: + key_info.append(args_match.group(0)) + + # Extract input/output token counts + token_matches = re.findall(r"#(Input|Output) tokens: \d+", output) + key_info.extend(token_matches) + + # Extract benchmark result section + result_match = re.search( + r"============ Serving Benchmark Result ============.*?={50,}", + output, + re.DOTALL, + ) + if result_match: + key_info.append(result_match.group(0)) + + return "\n\n".join(key_info) + + +def extract_port_from_command(cmd: str, server_type: str) -> int: + port_match = re.search(r"--port[= ](\d+)", cmd) + if port_match: + return int(port_match.group(1)) + return SERVER_DEFAULTS.get(server_type, ServerConfig("", [], 8000)).default_port + + +def detect_server_type(cmd: str) -> str: + for server_type, config in SERVER_DEFAULTS.items(): + if config.command in cmd: + return server_type + return "unknown" + + +def stream_output( + process: subprocess.Popen, prefix: str, logger: logging.Logger +) -> queue.Queue: + output_queue = queue.Queue() + + def stream_pipe(pipe, prefix): + for line in iter(pipe.readline, ""): + if prefix == "CLIENT": + output_queue.put(line.rstrip()) + logger.debug(f"{prefix} | {line.rstrip()}") + + stdout_thread = threading.Thread( + target=stream_pipe, args=(process.stdout, prefix), daemon=True + ) + stderr_thread = threading.Thread( + target=stream_pipe, args=(process.stderr, prefix), daemon=True + ) + + stdout_thread.start() + stderr_thread.start() + return output_queue, (stdout_thread, stderr_thread) + + +class ProcessManager: + def __init__(self): + self.server_process: Optional[subprocess.Popen] = None + self.client_process: Optional[subprocess.Popen] = None + self.logger = logging.getLogger(__name__) + + def start_process( + self, command: str, prefix: str + ) -> Tuple[subprocess.Popen, queue.Queue]: + process = subprocess.Popen( + command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, + ) + + output_queue, threads = stream_output(process, prefix, 
self.logger) + return process, output_queue, threads + + def kill_process_tree(self, process: subprocess.Popen): + try: + parent = psutil.Process(process.pid) + children = parent.children(recursive=True) + + for child in children: + try: + child.kill() + except psutil.NoSuchProcess: + pass + + parent.kill() + gone, alive = psutil.wait_procs(children + [parent], timeout=3) + + for p in alive: + try: + p.kill() + except psutil.NoSuchProcess: + pass + + except psutil.NoSuchProcess: + pass + + def cleanup(self, process_names: List[str]): + if self.client_process: + self.kill_process_tree(self.client_process) + self.client_process = None + + if self.server_process: + self.kill_process_tree(self.server_process) + self.server_process = None + + for proc in psutil.process_iter(["pid", "name", "cmdline"]): + try: + cmdline = " ".join(proc.cmdline()) + if any(name in cmdline for name in process_names): + proc.kill() + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + + +class ExperimentRunner: + def __init__(self): + self.process_manager = ProcessManager() + self.logger = logging.getLogger(__name__) + + def wait_for_server(self, port: int, timeout: int = 300) -> bool: + start_time = time.time() + + while time.time() - start_time < timeout: + try: + response = requests.get(f"http://localhost:{port}/health") + if response.status_code == 200: + self.logger.debug(f"Server ready on port {port}") + return True + except requests.RequestException: + time.sleep(2) + return False + + def run_task(self, config: TaskConfig) -> TaskResult: + start_time = time.time() + client_output = [] + + try: + if not config.server_type: + config.server_type = detect_server_type(config.server_cmd) + + server_config = SERVER_DEFAULTS.get(config.server_type) + if not server_config: + raise ValueError(f"Unknown server type: {config.server_type}") + + port = extract_port_from_command(config.server_cmd, config.server_type) + + self.process_manager.cleanup(server_config.process_names) + + self.logger.debug(f"Starting server: {config.name}") + self.process_manager.server_process, _, server_threads = ( + self.process_manager.start_process(config.server_cmd, "SERVER") + ) + + if not self.wait_for_server(port): + raise TimeoutError("Server startup timeout") + + time.sleep(10) + + self.logger.debug("Starting client") + self.process_manager.client_process, output_queue, client_threads = ( + self.process_manager.start_process(config.client_cmd, "CLIENT") + ) + + returncode = self.process_manager.client_process.wait() + + while True: + try: + line = output_queue.get_nowait() + client_output.append(line) + except queue.Empty: + break + + if returncode != 0: + raise RuntimeError(f"Client failed with code {returncode}") + + # Parse and format the output + full_output = "\n".join(client_output) + formatted_output = parse_key_info(full_output) + + return TaskResult( + name=config.name, + success=True, + output=formatted_output, + runtime=time.time() - start_time, + timestamp=datetime.now().isoformat(), + ) + + except Exception as e: + return TaskResult( + name=config.name, + success=False, + output=str(e), + runtime=time.time() - start_time, + timestamp=datetime.now().isoformat(), + ) + + finally: + if config.server_type in SERVER_DEFAULTS: + self.process_manager.cleanup( + SERVER_DEFAULTS[config.server_type].process_names + ) + time.sleep(10) + + +def load_config(config_path: str) -> List[TaskConfig]: + with open(config_path, "r") as f: + config_data = yaml.safe_load(f) + + configs = [] + for idx, entry in 
enumerate(config_data.get("tasks", [])): + if not isinstance(entry, dict): + raise ValueError(f"Invalid entry at index {idx}") + + config = TaskConfig( + server_cmd=entry.get("server_cmd"), + client_cmd=entry.get("client_cmd"), + name=entry.get("name", f"task-{idx+1}"), + server_type=entry.get("server_type"), + ) + + if not config.server_cmd or not config.client_cmd: + raise ValueError(f"Missing commands in {config.name}") + + configs.append(config) + + return configs + + +def setup_logging(debug: bool = False): + level = logging.DEBUG if debug else logging.INFO + logging.basicConfig( + level=level, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(), logging.FileHandler("experiment.log")], + ) + + +def format_results(results: List[TaskResult]) -> str: + """Format experiment results in Markdown for GitHub step summary.""" + output = ["# Experiment Results\n"] + + for result in results: + output.append(f"## {result.name}") + output.append(f"**Status**: {'✅ Success' if result.success else '❌ Failed'}") + output.append(f"**Runtime**: {result.runtime:.2f} seconds") + output.append(f"**Timestamp**: {result.timestamp}") + output.append("\n**Output**:\n```") + output.append(result.output) + output.append("```\n") + + return "\n".join(output) + + +def write_in_github_step_summary(results: List[TaskResult]): + """Write formatted results to GitHub step summary.""" + if not os.environ.get("GITHUB_STEP_SUMMARY"): + logging.warning("GITHUB_STEP_SUMMARY environment variable not set") + return + + formatted_content = format_results(results) + with open(os.environ["GITHUB_STEP_SUMMARY"], "a") as f: + f.write(formatted_content) + + +def main(): + parser = argparse.ArgumentParser(description="Experiment Runner") + parser.add_argument( + "--config", type=str, required=True, help="Path to YAML config file" + ) + parser.add_argument("--debug", action="store_true", help="Enable debug output") + args = parser.parse_args() + + setup_logging(args.debug) + logger = logging.getLogger(__name__) + results = [] + + try: + configs = load_config(args.config) + runner = ExperimentRunner() + + for config in configs: + logger.info(f"Running {config.name}") + result = runner.run_task(config) + results.append(result) + + write_in_github_step_summary(results) + except Exception as e: + logger.error(f"Error: {e}") + raise + + +if __name__ == "__main__": + main() diff --git a/test/srt/kv_cache_scales_llama3_1_8b.json b/test/srt/kv_cache_scales_llama3_1_8b.json new file mode 100644 index 00000000000..3e890e50e4a --- /dev/null +++ b/test/srt/kv_cache_scales_llama3_1_8b.json @@ -0,0 +1,42 @@ +{ + "model_type": "llama", + "kv_cache": { + "dtype": "float8_e4m3fn", + "scaling_factor": { + "0": { + "0": 1, + "1": 1, + "2": 1, + "3": 1, + "4": 1, + "5": 1, + "6": 1, + "7": 1, + "8": 1, + "9": 1, + "10": 1, + "11": 1, + "12": 1, + "13": 1, + "14": 1, + "15": 1, + "16": 1, + "17": 1, + "18": 1, + "19": 1, + "20": 1, + "21": 1, + "22": 1, + "23": 1, + "24": 1, + "25": 1, + "26": 1, + "27": 1, + "28": 1, + "29": 1, + "30": 1, + "31": 1 + } + } + } +} diff --git a/test/srt/kv_cache_scales_llama3_8b.json b/test/srt/kv_cache_scales_llama3_8b.json new file mode 100644 index 00000000000..466b0d01a74 --- /dev/null +++ b/test/srt/kv_cache_scales_llama3_8b.json @@ -0,0 +1,42 @@ +{ + "model_type": "llama", + "kv_cache": { + "dtype": "float8_e4m3fn", + "scaling_factor": { + "0": { + "0": 0.0408, + "1": 0.0503, + "2": 0.0667, + "3": 0.0909, + "4": 0.1135, + "5": 0.127, + "6": 0.1768, + "7": 0.1488, + "8": 0.1135, 
+ "9": 0.1203, + "10": 0.1013, + "11": 0.0842, + "12": 0.1231, + "13": 0.1096, + "14": 0.1221, + "15": 0.1013, + "16": 0.1067, + "17": 0.0952, + "18": 0.0899, + "19": 0.097, + "20": 0.087, + "21": 0.0994, + "22": 0.0904, + "23": 0.1013, + "24": 0.1019, + "25": 0.1053, + "26": 0.1, + "27": 0.0894, + "28": 0.1013, + "29": 0.1488, + "30": 0.0766, + "31": 0.0821 + } + } + } +} diff --git a/test/srt/kv_cache_scales_qwen2_1_5b.json b/test/srt/kv_cache_scales_qwen2_1_5b.json new file mode 100644 index 00000000000..984747509f7 --- /dev/null +++ b/test/srt/kv_cache_scales_qwen2_1_5b.json @@ -0,0 +1,38 @@ +{ + "model_type": "qwen", + "kv_cache": { + "dtype": "float8_e4m3fn", + "scaling_factor": { + "0": { + "0": 0.9846, + "1": 0.0645, + "2": 0.0731, + "3": 0.0800, + "4": 0.0748, + "5": 0.0780, + "6": 0.0702, + "7": 0.0894, + "8": 0.0410, + "9": 0.0758, + "10": 0.0556, + "11": 0.0731, + "12": 0.0899, + "13": 0.0780, + "14": 0.1441, + "15": 0.0914, + "16": 0.5614, + "17": 0.1067, + "18": 0.0537, + "19": 0.0658, + "20": 0.0523, + "21": 0.0533, + "22": 0.0699, + "23": 0.0635, + "24": 0.0588, + "25": 0.0884, + "26": 0.0947, + "27": 0.1032 + } + } + } +} diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index d9f1795341c..fd27d5c07b6 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -57,6 +57,7 @@ class ModelCase: ModelCase("openai-community/gpt2"), ModelCase("microsoft/Phi-3-small-8k-instruct"), ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True), + ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True), ] TORCH_DTYPES = [torch.float16] diff --git a/test/srt/models/test_qwen_models.py b/test/srt/models/test_qwen_models.py new file mode 100644 index 00000000000..c7788fa8e50 --- /dev/null +++ b/test/srt/models/test_qwen_models.py @@ -0,0 +1,76 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestQwen2(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "Qwen/Qwen2-7B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.81) + + +class TestQwen2FP8(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "neuralmagic/Qwen2-7B-Instruct-FP8" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.79) 
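+        # Note: the 0.79 threshold here is looser than the 0.81 used for the
+        # unquantized checkpoint above, allowing for FP8 quantization effects.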
+ + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/models/test_reward_models.py b/test/srt/models/test_reward_models.py index 0d80a4d0cde..69ad563671b 100644 --- a/test/srt/models/test_reward_models.py +++ b/test/srt/models/test_reward_models.py @@ -20,8 +20,8 @@ from sglang.test.runners import HFRunner, SRTRunner MODELS = [ - ("LxzGordon/URM-LLaMa-3.1-8B", 1, 3e-2), - ("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 3e-2), + ("LxzGordon/URM-LLaMa-3.1-8B", 1, 4e-2), + ("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 4e-2), ] TORCH_DTYPES = [torch.float16] diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 5035810f86a..603bab957bd 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -4,28 +4,35 @@ from sglang.test.test_utils import run_unittest_files suites = { - "minimal": [ + "per-commit": [ "models/test_embedding_models.py", "models/test_generation_models.py", "models/test_lora.py", + "models/test_qwen_models.py", "models/test_reward_models.py", "sampling/penaltylib", "test_abort.py", "test_chunked_prefill.py", + "test_custom_allreduce.py", "test_double_sparsity.py", + "test_eagle_infer.py", "test_embedding_openai_server.py", "test_eval_accuracy_mini.py", - "test_get_weights_by_name.py", "test_gguf.py", "test_input_embeddings.py", "test_json_constrained.py", "test_large_max_new_tokens.py", "test_metrics.py", + "test_mla.py", + "test_mla_fp8.py", "test_no_chunked_prefill.py", "test_no_overlap_scheduler.py", "test_openai_server.py", "test_pytorch_sampling_backend.py", "test_radix_attention.py", + "test_regex_constrained.py", + "test_release_memory_occupation.py", + "test_request_length_validation.py", "test_retract_decode.py", "test_server_args.py", "test_session_control.py", @@ -39,14 +46,25 @@ "test_triton_attention_kernels.py", "test_triton_attention_backend.py", "test_update_weights_from_disk.py", + "test_update_weights_from_tensor.py", + "test_vision_chunked_prefill.py", + "test_vision_llm.py", "test_vision_openai_server.py", - "test_session_control.py", + "test_w8a8_quantization.py", + "test_fp8_kvcache.py", + "test_fp8_kernel.py", + ], + "nightly": [ + "test_nightly_gsm8k_eval.py", + # Disable temporarily + # "test_nightly_math_eval.py", ], "sampling/penaltylib": glob.glob( "sampling/penaltylib/**/test_*.py", recursive=True ), } +# Expand suite for target_suite_name, target_tests in suites.items(): for suite_name, tests in suites.items(): if suite_name == target_suite_name: @@ -55,7 +73,6 @@ tests.remove(target_suite_name) tests.extend(target_tests) - if __name__ == "__main__": arg_parser = argparse.ArgumentParser() arg_parser.add_argument( diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py index 0eccb3407f1..d9d77a9ae24 100644 --- a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py +++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py @@ -6,7 +6,7 @@ from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( - DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, popen_launch_server, @@ -17,7 +17,7 @@ class TestBatchPenalizerE2E(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( cls.model, @@ -58,8 +58,7 @@ def run_decode( "logprob_start_len": 0, }, ) - 
print(json.dumps(response.json())) - print("=" * 100) + assert response.status_code == 200, "Request failed: " + response.text def test_default_values(self): self.run_decode() @@ -112,4 +111,4 @@ def test_repetition_penalty(self): if __name__ == "__main__": - unittest.main() + unittest.main(verbosity=3) diff --git a/test/srt/test_bench_one_batch.py b/test/srt/test_bench_one_batch.py index c1bc98e8e04..c6562170d61 100644 --- a/test/srt/test_bench_one_batch.py +++ b/test/srt/test_bench_one_batch.py @@ -5,24 +5,46 @@ DEFAULT_MOE_MODEL_NAME_FOR_TEST, is_in_ci, run_bench_one_batch, + write_github_step_summary, ) class TestBenchOneBatch(unittest.TestCase): - def test_default(self): + def test_bs1(self): output_throughput = run_bench_one_batch(DEFAULT_MODEL_NAME_FOR_TEST, []) if is_in_ci(): + write_github_step_summary( + f"### test_bs1\n" + f"output_throughput : {output_throughput:.2f} token/s\n" + ) self.assertGreater(output_throughput, 135) - def test_moe_default(self): + def test_moe_tp2_bs1(self): output_throughput = run_bench_one_batch( DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2"] ) if is_in_ci(): + write_github_step_summary( + f"### test_moe_tp2_bs1\n" + f"output_throughput : {output_throughput:.2f} token/s\n" + ) self.assertGreater(output_throughput, 125) + def test_torch_compile_tp2_bs1(self): + output_throughput = run_bench_one_batch( + DEFAULT_MODEL_NAME_FOR_TEST, + ["--tp", "2", "--enable-torch-compile", "--cuda-graph-max-bs", "2"], + ) + + if is_in_ci(): + write_github_step_summary( + f"### test_torch_compile_tp2_bs1\n" + f"output_throughput : {output_throughput:.2f} token/s\n" + ) + self.assertGreater(output_throughput, 240) + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index 34a7b6c9670..8233438fcaf 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -1,6 +1,8 @@ import unittest from sglang.test.test_utils import ( + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_FP8_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MOE_MODEL_NAME_FOR_TEST, @@ -47,7 +49,7 @@ def test_offline_throughput_non_stream_small_batch_size(self): ) # There is a regression with torch 2.5 # This number was 950 for torch 2.4 - self.assertGreater(res["output_throughput"], 800) + self.assertGreater(res["output_throughput"], 1000) def test_offline_throughput_without_radix_cache(self): res = run_bench_serving( @@ -112,7 +114,7 @@ def test_offline_throughput_default_fp8(self): f"### test_offline_throughput_default_fp8\n" f'Output throughput: {res["output_throughput"]:.2f} token/s\n' ) - self.assertGreater(res["output_throughput"], 3850) + self.assertGreater(res["output_throughput"], 3900) def test_online_latency_default(self): res = run_bench_serving( @@ -125,12 +127,42 @@ def test_online_latency_default(self): if is_in_ci(): write_github_step_summary( f"### test_online_latency_default\n" - f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} token/s\n' + f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n' ) - self.assertLess(res["median_e2e_latency_ms"], 12000) + self.assertLess(res["median_e2e_latency_ms"], 11000) self.assertLess(res["median_ttft_ms"], 86) self.assertLess(res["median_itl_ms"], 10) + def test_online_latency_eagle(self): + res = run_bench_serving( + model=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + num_prompts=50, + request_rate=1, + disable_ignore_eos=True, + dataset_name="sharegpt", + other_server_args=[ + 
"--speculative-algorithm", + "EAGLE", + "--speculative-draft-model-path", + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + "--speculative-num-steps", + "5", + "--speculative-eagle-topk", + "8", + "--speculative-num-draft-tokens", + "64", + "--mem-fraction-static", + "0.7", + ], + ) + + if is_in_ci(): + write_github_step_summary( + f"### test_online_latency_eagle\n" + f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n' + ) + self.assertLess(res["median_e2e_latency_ms"], 450) + def test_moe_offline_throughput_default(self): res = run_bench_serving( model=DEFAULT_MOE_MODEL_NAME_FOR_TEST, @@ -144,7 +176,7 @@ def test_moe_offline_throughput_default(self): f"### test_moe_offline_throughput_default\n" f'Output throughput: {res["output_throughput"]:.2f} token/s\n' ) - self.assertGreater(res["output_throughput"], 2150) + self.assertGreater(res["output_throughput"], 2200) def test_moe_offline_throughput_without_radix_cache(self): res = run_bench_serving( @@ -159,7 +191,7 @@ def test_moe_offline_throughput_without_radix_cache(self): f"### test_moe_offline_throughput_without_radix_cache\n" f'Output throughput: {res["output_throughput"]:.2f} token/s\n' ) - self.assertGreater(res["output_throughput"], 2150) + self.assertGreater(res["output_throughput"], 2200) if __name__ == "__main__": diff --git a/test/srt/test_custom_allreduce.py b/test/srt/test_custom_allreduce.py new file mode 100644 index 00000000000..5f6f5d9b491 --- /dev/null +++ b/test/srt/test_custom_allreduce.py @@ -0,0 +1,164 @@ +import os +import random +import socket +import unittest +from typing import Any + +import ray +import torch +import torch.distributed as dist + +from sglang.srt.distributed import init_distributed_environment +from sglang.srt.distributed.communication_op import ( # noqa + tensor_model_parallel_all_reduce, +) +from sglang.srt.distributed.parallel_state import ( + get_tensor_model_parallel_group, + graph_capture, + initialize_model_parallel, +) + + +def get_open_port() -> int: + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def multi_process_parallel( + world_size: int, + cls: Any, + test_target: Any, +) -> None: + + # Using ray helps debugging the error when it failed + # as compared to multiprocessing. 
+ # NOTE: We need to set working_dir for distributed tests, + # otherwise we may get import errors on ray workers + ray.init(log_to_driver=False) + + distributed_init_port = get_open_port() + refs = [] + for rank in range(world_size): + refs.append(test_target.remote(cls, world_size, rank, distributed_init_port)) + ray.get(refs) + + ray.shutdown() + + +class TestCustomAllReduce(unittest.TestCase): + @classmethod + def setUpClass(cls): + random.seed(42) + # 512B to 32MB + cls.test_sizes = [512, 4096, 32768, 262144, 2097152, 16777216, 33554432] + cls.world_sizes = [2, 4, 6, 8] + cls.test_loop = 10 + + def test_graph_allreduce(self): + for world_size in self.world_sizes: + if world_size > torch.cuda.device_count(): + continue + multi_process_parallel(world_size, self, self.graph_allreduce) + + def test_eager_allreduce(self): + for world_size in self.world_sizes: + if world_size > torch.cuda.device_count(): + continue + multi_process_parallel(world_size, self, self.eager_allreduce) + + @ray.remote(num_gpus=1, max_calls=1) + def graph_allreduce(self, world_size, rank, distributed_init_port): + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + init_distributed_environment( + world_size=world_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=rank, + ) + initialize_model_parallel(tensor_model_parallel_size=world_size) + group = get_tensor_model_parallel_group().device_group + + # A small all_reduce for warmup. + # this is needed because device communicators might be created lazily + # (e.g. NCCL). This will ensure that the communicator is initialized + # before any communication happens, so that this group can be used for + # graph capture immediately. 
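        # (Explanatory note, not part of the original patch.) Running the
        # warmup outside of capture also matters because NCCL communicator
        # setup allocates buffers and creates internal streams, which cannot
        # safely happen inside a CUDA graph capture region.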
+ data = torch.zeros(1) + data = data.to(device=device) + torch.distributed.all_reduce(data, group=group) + torch.cuda.synchronize() + del data + + for sz in self.test_sizes: + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + for _ in range(self.test_loop): + with graph_capture() as graph_capture_context: + # use integers so result matches NCCL exactly + inp1 = torch.randint( + 1, + 16, + (sz,), + dtype=dtype, + device=torch.cuda.current_device(), + ) + inp2 = torch.randint( + 1, + 16, + (sz,), + dtype=dtype, + device=torch.cuda.current_device(), + ) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph( + graph, stream=graph_capture_context.stream + ): + out1 = tensor_model_parallel_all_reduce(inp1) + # the input buffer is immediately modified to test + # synchronization + dist.all_reduce(inp1, group=group) + out2 = tensor_model_parallel_all_reduce(inp2) + dist.all_reduce(inp2, group=group) + graph.replay() + torch.testing.assert_close(out1, inp1) + torch.testing.assert_close(out2, inp2) + + @ray.remote(num_gpus=1, max_calls=1) + def eager_allreduce(self, world_size, rank, distributed_init_port): + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + init_distributed_environment( + world_size=world_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=rank, + ) + initialize_model_parallel(tensor_model_parallel_size=world_size) + group = get_tensor_model_parallel_group().device_group + + for sz in self.test_sizes: + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + for _ in range(self.test_loop): + inp1 = torch.randint( + 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) + out1 = tensor_model_parallel_all_reduce(inp1) + dist.all_reduce(inp1, group=group) + torch.testing.assert_close(out1, inp1) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py new file mode 100644 index 00000000000..b01c260496a --- /dev/null +++ b/test/srt/test_eagle_infer.py @@ -0,0 +1,180 @@ +import random +import threading +import time +import unittest +from types import SimpleNamespace + +import requests + +import sglang as sgl +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestEAGLEEngine(unittest.TestCase): + + def test_eagle_accuracy(self): + prompt = "Today is a sunny day and I like" + sampling_params = {"temperature": 0, "max_new_tokens": 8} + + # Get the reference output + ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST) + ref_output = ref_engine.generate(prompt, sampling_params)["text"] + ref_engine.shutdown() + + # Launch EAGLE engine + engine = sgl.Engine( + model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + speculative_algorithm="EAGLE", + speculative_num_steps=5, + speculative_eagle_topk=8, + speculative_num_draft_tokens=64, + mem_fraction_static=0.7, + ) + + # Case 1: Test the output of EAGLE engine is the same as normal engine + out1 = engine.generate(prompt, sampling_params)["text"] + print(f"{out1=}, 
{ref_output=}") + self.assertEqual(out1, ref_output) + + # Case 2: Test the output of EAGLE engine does not contain unexpected EOS + prompt = "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like [/INST]" + sampling_params = { + "temperature": 0, + "max_new_tokens": 1024, + "skip_special_tokens": False, + } + + tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST) + out2 = engine.generate(prompt, sampling_params)["text"] + print(f"{out2=}") + tokens = tokenizer.encode(out2, truncation=False) + assert tokenizer.eos_token_id not in tokens + + # Case 3: Batched prompts + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = {"temperature": 0, "max_new_tokens": 30} + outputs = engine.generate(prompts, sampling_params) + for prompt, output in zip(prompts, outputs): + print("===============================") + print(f"Prompt: {prompt}\nGenerated text: {output['text']}") + + # Shutdown the engine + engine.shutdown() + + +prompts = [ + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like[/INST]" + '[INST] <>\\nYou are a helpful assistant.\\n<>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]', + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]", + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwho are you?[/INST]", + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwhere are you from?[/INST]", +] + + +class TestEAGLEServer(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--speculative-algorithm", + "EAGLE", + "--speculative-draft-model-path", + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + "--speculative-num-steps", + "5", + "--speculative-eagle-topk", + "8", + "--speculative-num-draft-tokens", + "64", + "--mem-fraction-static", + "0.7", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def send_request(self): + time.sleep(random.uniform(0, 2)) + for prompt in prompts: + url = self.base_url + "/generate" + data = { + "text": prompt, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 1024, + }, + } + response = requests.post(url, json=data) + assert response.status_code == 200 + + def send_requests_abort(self): + for prompt in prompts: + try: + time.sleep(random.uniform(0, 2)) + url = self.base_url + "/generate" + data = { + "model": "base", + "text": prompt, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 1024, + }, + } + # set timeout = 1s,mock disconnected + requests.post(url, json=data, timeout=1) + except Exception as e: + print(e) + pass + + def test_request_abort(self): + concurrency = 4 + threads = [ + threading.Thread(target=self.send_request) for _ in range(concurrency) + ] + [ + threading.Thread(target=self.send_requests_abort) + for _ in range(concurrency) + ] + for worker in threads: + worker.start() + for p in threads: + p.join() + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + + self.assertGreater(metrics["accuracy"], 0.20) + + +if 
__name__ == "__main__": + unittest.main() diff --git a/test/srt/test_ebnf_constrained.py b/test/srt/test_ebnf_constrained.py new file mode 100644 index 00000000000..5e852bec6e4 --- /dev/null +++ b/test/srt/test_ebnf_constrained.py @@ -0,0 +1,240 @@ +""" +python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_email +python3 -m unittest test_ebnf_constrained.TestEBNFConstrained.test_ebnf_generate_greeting +""" + +import json +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +def setup_class(cls, disable_overlap: bool): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.ebnf_grammar = 'root ::= "test"' # Default grammar + + other_args = [ + "--max-running-requests", + "10", + "--grammar-backend", + "xgrammar", + ] + + if disable_overlap: + other_args += ["--disable-overlap-schedule"] + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + +class TestEBNFConstrained(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_class(cls, disable_overlap=False) + cls.check_jump_forward = False + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def run_decode( + self, + ebnf, + expected_patterns, + prompt, + return_logprob=False, + top_logprobs_num=0, + n=1, + ): + response = requests.post( + self.base_url + "/generate", + json={ + "text": prompt, + "sampling_params": { + "temperature": 0 if n == 1 else 0.5, + "max_new_tokens": 128, + "n": n, + "ebnf": ebnf, + }, + "stream": False, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "logprob_start_len": 0, + }, + ) + + ret = response.json() + print(json.dumps(ret, indent=2)) + print("=" * 100) + + if not isinstance(ret, list): + self.fail(f"Expected response to be a list, but got {type(ret)}") + + for item in ret: + text = item.get("text", "").strip() + if not text: + self.fail("Generated text is empty.") + + match = False + for pattern in expected_patterns: + if self.regex_match(text, pattern): + match = True + break + if not match: + self.fail(f"Text '{text}' does not match any of the allowed patterns.") + + def regex_match(self, text, pattern): + import re + + return re.match(pattern, text) is not None + + def test_ebnf_generate_email(self): + self.__class__.ebnf_grammar = 'root ::= "user@example.com"' + allowed_patterns = [r"^user@example\.com$"] + prompt = "Generate an email address:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + def test_ebnf_generate_greeting(self): + self.__class__.ebnf_grammar = 'root ::= "Hello" | "Hi" | "Hey"' + allowed_patterns = [r"^(Hello|Hi|Hey)$"] + prompt = "Generate a greeting:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + def test_ebnf_generate_number(self): + self.__class__.ebnf_grammar = """ + root ::= digit digit digit + digit ::= [0-9] + """ + allowed_patterns = [r"^\d{3}$"] + prompt = "Generate a three-digit number:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + def test_ebnf_generate_phone(self): + self.__class__.ebnf_grammar = """ + root ::= "(" area 
")" " " prefix "-" line + area ::= [0-9] [0-9] [0-9] + prefix ::= [0-9] [0-9] [0-9] + line ::= [0-9] [0-9] [0-9] [0-9] + """ + allowed_patterns = [r"^\(\d{3}\) \d{3}-\d{4}$"] + prompt = "Generate a phone number:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + def test_ebnf_generate_date(self): + self.__class__.ebnf_grammar = """ + root ::= year "-" month "-" day + year ::= "2024" + month ::= "01" | "02" | "03" | "04" | "05" | "06" | "07" | "08" | "09" | "10" | "11" | "12" + day ::= "01" | "02" | "03" | "04" | "05" | "06" | "07" | "08" | "09" | "10" | + "11" | "12" | "13" | "14" | "15" | "16" | "17" | "18" | "19" | "20" | + "21" | "22" | "23" | "24" | "25" | "26" | "27" | "28" | "29" | "30" | "31" + """ + allowed_patterns = [r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$"] + prompt = "Generate a date in YYYY-MM-DD format:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + def test_ebnf_generate_hex_color(self): + self.__class__.ebnf_grammar = """ + root ::= "#" hex hex hex hex hex hex + hex ::= [0-9] | [A-F] + """ + allowed_patterns = [r"^#[0-9A-F]{6}$"] + prompt = "Generate a hex color code:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + def test_ebnf_generate_complex_json(self): + self.__class__.ebnf_grammar = """ + root ::= object + object ::= "{" ws pair (ws "," ws pair)* ws "}" + pair ::= "\\"name\\"" ws ":" ws value | + "\\"age\\"" ws ":" ws number | + "\\"city\\"" ws ":" ws string + value ::= string | number + string ::= "\\"" [a-zA-Z0-9 ]+ "\\"" + number ::= [1-9] [0-9]* + ws ::= [ ]* + """ + allowed_patterns = [ + r'^{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*}$', + ] + prompt = "Generate a simple JSON with name, age, and city:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + def test_ebnf_generate_custom_log_format(self): + self.__class__.ebnf_grammar = """ + root ::= logentry + logentry ::= "[" datetime "] " level ": System.process - " message + datetime ::= "2024-01-01T12:00:00Z" + level ::= "INFO" + message ::= "Operation " [a-z]+ " successfully" + """ + allowed_patterns = [ + r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$" + ] + prompt = "Generate a log entry:" + + self.run_decode( + ebnf=self.__class__.ebnf_grammar, + expected_patterns=allowed_patterns, + prompt=prompt, + n=3, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_fp8_kernel.py b/test/srt/test_fp8_kernel.py new file mode 100644 index 00000000000..fe92bfd0769 --- /dev/null +++ b/test/srt/test_fp8_kernel.py @@ -0,0 +1,127 @@ +import unittest + +import torch + +from sglang.srt.layers.quantization.fp8_kernel import ( + per_token_group_quant_fp8, + w8a8_block_fp8_matmul, +) + + +class TestFP8Base(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.M = 256 + # test non-aligned + cls.N = 1024 + 64 + cls.K = 512 + cls.group_size = 128 + cls.quant_type = torch.float8_e4m3fn + cls.output_type = torch.float16 + + @staticmethod + def _make_A(M, K, group_size, out_dtype): + quant_A = torch.rand( + M, K // group_size, group_size, dtype=torch.float32, device="cuda" + ) + # -1 ~ 1 + quant_A = quant_A * 2 - 1 + # scaling abs max to fmax + finfo = torch.finfo(out_dtype) + fmax = 
finfo.max + scaling = fmax / quant_A.abs().amax(-1, keepdim=True) + quant_A *= scaling + quant_A = quant_A.to(out_dtype).to(torch.float32) + + # create scale and A + scale = torch.rand(M, K // group_size, dtype=torch.float32, device="cuda") + scale /= fmax + A = quant_A * scale[..., None] + + A = A.reshape(M, K) + quant_A = quant_A.reshape(M, K).to(out_dtype) + return A, quant_A, scale + + @staticmethod + def _make_B(K, N, group_size, out_dtype): + def _aligned_size(a, b): + return (a + b - 1) // b * b + + K_aligned = _aligned_size(K, group_size) + N_aligned = _aligned_size(N, group_size) + + quant_B = torch.rand( + K_aligned // group_size, + group_size, + N_aligned // group_size, + group_size, + dtype=torch.float32, + device="cuda", + ) + quant_B = quant_B * 2 - 1 + + # scaling abs max to fmax + finfo = torch.finfo(out_dtype) + fmax = finfo.max + scaling = fmax / quant_B.abs().amax((1, 3), keepdim=True) + quant_B *= scaling + quant_B = quant_B.to(out_dtype).to(torch.float32) + + scale = torch.rand( + K_aligned // group_size, + 1, + N_aligned // group_size, + 1, + dtype=torch.float32, + device="cuda", + ) + scale /= fmax + + B = quant_B * scale + + B = B.reshape(K_aligned, N_aligned)[:K, :N] + quant_B = quant_B.reshape(K_aligned, N_aligned).to(out_dtype)[:K, :N] + scale = scale.reshape(K_aligned // group_size, N_aligned // group_size) + return B, quant_B, scale + + +class TestPerTokenGroupQuantFP8(TestFP8Base): + def test_per_token_group_quant_fp8(self): + if torch.cuda.get_device_capability()[0] < 9: + return + A, A_quant_gt, scale_gt = self._make_A( + M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type + ) + A_quant, scale = per_token_group_quant_fp8( + x=A, group_size=self.group_size, dtype=self.quant_type + ) + torch.testing.assert_close(scale, scale_gt) + diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs() + diff_count = (diff > 1e-5).count_nonzero() + assert diff_count / diff.numel() < 1e-4 + + +class TestW8A8BlockFP8Matmul(TestFP8Base): + def test_w8a8_block_fp8_matmul(self): + if torch.cuda.get_device_capability()[0] < 9: + return + A, A_quant_gt, A_scale_gt = self._make_A( + M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type + ) + B, B_quant_gt, B_scale_gt = self._make_B( + K=self.K, N=self.N, group_size=self.group_size, out_dtype=self.quant_type + ) + C_gt = A.to(self.output_type) @ B.to(self.output_type) + C = w8a8_block_fp8_matmul( + A=A_quant_gt, + B=B_quant_gt.T.contiguous(), + As=A_scale_gt, + Bs=B_scale_gt.T.contiguous(), + block_size=[128, 128], + output_dtype=self.output_type, + ) + torch.testing.assert_close(C, C_gt, atol=0.5, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_fp8_kvcache.py b/test/srt/test_fp8_kvcache.py new file mode 100644 index 00000000000..4a8a2434699 --- /dev/null +++ b/test/srt/test_fp8_kvcache.py @@ -0,0 +1,113 @@ +import os +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestFp8KvcacheBase(unittest.TestCase): + model_config = None + + @classmethod + def setUpClass(cls): + if cls.model_config is None: + raise NotImplementedError("model_config must be specified in subclass") + + cls.model = cls.model_config["model_name"] + cls.base_url = 
DEFAULT_URL_FOR_TEST + dirpath = os.path.dirname(__file__) + config_file = os.path.join(dirpath, cls.model_config["config_filename"]) + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--kv-cache-dtype", + "fp8_e4m3", + "--quantization-param-path", + config_file, + ], + ) + + +class TestFp8KvcacheLlama(TestFp8KvcacheBase): + model_config = { + "model_name": DEFAULT_MODEL_NAME_FOR_TEST, + "config_filename": "kv_cache_scales_llama3_8b.json", + } + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + self.assertGreater(metrics["score"], 0.80) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], 0.65) + + +class TestFp8KvcacheQwen(TestFp8KvcacheBase): + model_config = { + "model_name": DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN, + "config_filename": "kv_cache_scales_qwen2_1_5b.json", + } + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + self.assertGreater(metrics["score"], 0.01) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], 0.3) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_function_calling.py b/test/srt/test_function_calling.py new file mode 100644 index 00000000000..24f341a5e47 --- /dev/null +++ b/test/srt/test_function_calling.py @@ -0,0 +1,249 @@ +import json +import time +import unittest + +import openai + +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestOpenAIServerFunctionCalling(unittest.TestCase): + @classmethod + def setUpClass(cls): + # Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + + # Start the local OpenAI Server. If necessary, you can add other parameters such as --enable-tools. + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=[ + # If your server needs extra parameters to test function calling, please add them here. + "--tool-call-parser", + "llama3", + ], + ) + cls.base_url += "/v1" + cls.tokenizer = get_tokenizer(cls.model) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_function_calling_format(self): + """ + Test: Whether the function call format returned by the AI is correct. + When returning a tool call, message.content should be None, and tool_calls should be a list. 
+ """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + tools = [ + { + "type": "function", + "function": { + "name": "add", + "description": "Compute the sum of two numbers", + "parameters": { + "type": "object", + "properties": { + "a": { + "type": "int", + "description": "A number", + }, + "b": { + "type": "int", + "description": "A number", + }, + }, + "required": ["a", "b"], + }, + }, + } + ] + + messages = [{"role": "user", "content": "Compute (3+5)"}] + response = client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.8, + top_p=0.8, + stream=False, + tools=tools, + ) + + content = response.choices[0].message.content + tool_calls = response.choices[0].message.tool_calls + + assert content is None, ( + "When function call is successful, message.content should be None, " + f"but got: {content}" + ) + assert ( + isinstance(tool_calls, list) and len(tool_calls) > 0 + ), "tool_calls should be a non-empty list" + + function_name = tool_calls[0].function.name + assert function_name == "add", "Function name should be 'add'" + + def test_function_calling_streaming_simple(self): + """ + Test: Whether the function name can be correctly recognized in streaming mode. + - Expect a function call to be found, and the function name to be correct. + - Verify that streaming mode returns at least multiple chunks. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for", + }, + "unit": { + "type": "string", + "description": "Weather unit (celsius or fahrenheit)", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["city", "unit"], + }, + }, + } + ] + + messages = [{"role": "user", "content": "What is the temperature in Paris?"}] + + response_stream = client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.8, + top_p=0.8, + stream=True, + tools=tools, + ) + + chunks = list(response_stream) + self.assertTrue(len(chunks) > 0, "Streaming should return at least one chunk") + + found_function_name = False + for chunk in chunks: + choice = chunk.choices[0] + # Check whether the current chunk contains tool_calls + if choice.delta.tool_calls: + tool_call = choice.delta.tool_calls[0] + if tool_call.function.name: + self.assertEqual( + tool_call.function.name, + "get_current_weather", + "Function name should be 'get_current_weather'", + ) + found_function_name = True + break + + self.assertTrue( + found_function_name, + "Target function name 'get_current_weather' was not found in the streaming chunks", + ) + + def test_function_calling_streaming_args_parsing(self): + """ + Test: Whether the function call arguments returned in streaming mode can be correctly concatenated into valid JSON. + - The user request requires multiple parameters. + - AI may return the arguments in chunks that need to be concatenated. 
+ """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + + tools = [ + { + "type": "function", + "function": { + "name": "add", + "description": "Compute the sum of two integers", + "parameters": { + "type": "object", + "properties": { + "a": { + "type": "int", + "description": "First integer", + }, + "b": { + "type": "int", + "description": "Second integer", + }, + }, + "required": ["a", "b"], + }, + }, + } + ] + + messages = [ + {"role": "user", "content": "Please sum 5 and 7, just call the function."} + ] + + response_stream = client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.9, + top_p=0.9, + stream=True, + tools=tools, + ) + + argument_fragments = [] + function_name = None + for chunk in response_stream: + choice = chunk.choices[0] + if choice.delta.tool_calls: + tool_call = choice.delta.tool_calls[0] + # Record the function name on first occurrence + function_name = tool_call.function.name or function_name + # In case of multiple chunks, JSON fragments may need to be concatenated + if tool_call.function.arguments: + argument_fragments.append(tool_call.function.arguments) + + self.assertEqual(function_name, "add", "Function name should be 'add'") + joined_args = "".join(argument_fragments) + self.assertTrue( + len(joined_args) > 0, + "No parameter fragments were returned in the function call", + ) + + # Check whether the concatenated JSON is valid + try: + args_obj = json.loads(joined_args) + except json.JSONDecodeError: + self.fail( + "The concatenated tool call arguments are not valid JSON, parsing failed" + ) + + self.assertIn("a", args_obj, "Missing parameter 'a'") + self.assertIn("b", args_obj, "Missing parameter 'b'") + self.assertEqual( + args_obj["a"], + 5, + "Parameter a should be 5", + ) + self.assertEqual(args_obj["b"], 7, "Parameter b should be 7") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py new file mode 100644 index 00000000000..80aeab257c3 --- /dev/null +++ b/test/srt/test_fused_moe.py @@ -0,0 +1,126 @@ +import unittest + +import torch +from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe + + +class TestFusedMOE(unittest.TestCase): + NUM_EXPERTS = [8, 64] + TOP_KS = [2, 6] + + def torch_naive_moe(self, a, w1, w2, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[ + i + ].transpose(0, 1) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): + if use_fp8_w8a8: + # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 + capability = torch.cuda.get_device_capability() + if not (capability[0] >= 9 or capability == (8, 9)): + return + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + w1 = 
w1.to(torch.float8_e4m3fn) + w2 = w2.to(torch.float8_e4m3fn) + score = torch.randn((m, e), device="cuda", dtype=dtype) + + w1_scale = torch.randn(e, dtype=torch.float32, device="cuda") + w2_scale = torch.randn(e, dtype=torch.float32, device="cuda") + a1_scale = torch.randn(1, dtype=torch.float32, device="cuda") + a2_scale = torch.randn(1, dtype=torch.float32, device="cuda") + + sglang_output = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + + vllm_output = fused_moe_vllm( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) + + torch.testing.assert_close(sglang_output, vllm_output, atol=2e-2, rtol=0) + + else: + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) + torch_output = self.torch_naive_moe(a, w1, w2, score, topk) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) + + def test_various_configurations(self): + m_values = [1, 33, 64, 222, 1024 * 128] + n_values = [128, 1024, 2048] + k_values = [128, 511, 1024] + dtypes = [torch.float16, torch.bfloat16] + fp8_modes = [False, True] + + for m in m_values: + for n in n_values: + for k in k_values: + for e in self.NUM_EXPERTS: + for topk in self.TOP_KS: + for dtype in dtypes: + for use_fp8_w8a8 in fp8_modes: + with self.subTest( + m=m, + n=n, + k=k, + e=e, + topk=topk, + dtype=dtype, + fp8=use_fp8_w8a8, + ): + self._test_case( + m, + n, + k, + e, + topk, + dtype, + use_fp8_w8a8=use_fp8_w8a8, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py index 1a857d0da6e..adb5c18fbe2 100644 --- a/test/srt/test_json_constrained.py +++ b/test/srt/test_json_constrained.py @@ -95,15 +95,6 @@ def run_decode(self, json_schema, return_logprob=False, top_logprobs_num=0, n=1) self.assertIsInstance(js_obj["name"], str) self.assertIsInstance(js_obj["population"], int) - # Make sure jump forward is triggered - # NOTE: The overlap scheduler does not support jump forward so we only do this test - # when --disable-overlap-schedule is set. 
- if self.check_jump_forward: - self.assertGreater( - ret["meta_info"]["completion_tokens"], - ret["meta_info"]["completion_tokens_wo_jump_forward"], - ) - def test_json_generate(self): self.run_decode(json_schema=self.json_schema) diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py index 3b73e500d77..2837107a1e6 100644 --- a/test/srt/test_metrics.py +++ b/test/srt/test_metrics.py @@ -51,12 +51,14 @@ def test_metrics_enabled(self): # Verify essential metrics are present essential_metrics = [ "sglang:num_running_reqs", + "sglang:num_used_tokens", "sglang:token_usage", "sglang:gen_throughput", + "sglang:num_queue_reqs", "sglang:cache_hit_rate", - "sglang:func_latency_seconds", "sglang:prompt_tokens_total", "sglang:generation_tokens_total", + "sglang:num_requests_total", "sglang:time_to_first_token_seconds", "sglang:time_per_output_token_seconds", "sglang:e2e_request_latency_seconds", diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py index b8105a84af1..34bc4b44645 100644 --- a/test/srt/test_mla.py +++ b/test/srt/test_mla.py @@ -2,6 +2,7 @@ from types import SimpleNamespace from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -20,7 +21,7 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--tp", "2", "--trust-remote-code"], + other_args=["--trust-remote-code"], ) @classmethod @@ -52,5 +53,37 @@ def test_mgsm_en(self): self.assertGreater(metrics["score"], 0.8) +class TestDeepseekV3(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "lmzheng/sglang-ci-dsv3-test" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--trust-remote-code"], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.62) + + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_mla_fp8.py b/test/srt/test_mla_fp8.py index 769bdf34da8..4fe18b526b1 100644 --- a/test/srt/test_mla_fp8.py +++ b/test/srt/test_mla_fp8.py @@ -21,8 +21,6 @@ def setUpClass(cls): cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ - "--tp", - "2", "--trust-remote-code", "--kv-cache-dtype", "fp8_e5m2", diff --git a/test/srt/test_moe_ep.py b/test/srt/test_moe_ep.py index 4d9fd435edb..9f87eb24d71 100644 --- a/test/srt/test_moe_ep.py +++ b/test/srt/test_moe_ep.py @@ -44,7 +44,7 @@ def test_mmlu(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.5 + self.assertGreater(metrics["score"], 0.5) def test_mgsm_en(self): args = SimpleNamespace( @@ -56,7 +56,7 @@ def test_mgsm_en(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.8 + self.assertGreater(metrics["score"], 0.8) class TestEpMoEFP8(unittest.TestCase): diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py index 6f3affbba4d..dc420f00dfa 100644 --- a/test/srt/test_moe_eval_accuracy_large.py +++ b/test/srt/test_moe_eval_accuracy_large.py @@ -71,7 +71,7 @@ def 
test_mgsm_en(self): ) metrics = run_eval(args) - self.assertGreater(metrics["score"], 0.62) + self.assertGreater(metrics["score"], 0.61) if __name__ == "__main__": diff --git a/test/srt/test_nightly_gsm8k_eval.py b/test/srt/test_nightly_gsm8k_eval.py index 8466c2c6489..6fe36171504 100644 --- a/test/srt/test_nightly_gsm8k_eval.py +++ b/test/srt/test_nightly_gsm8k_eval.py @@ -1,6 +1,5 @@ import json import os -import subprocess import unittest import warnings from datetime import datetime @@ -16,27 +15,29 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, + is_in_ci, popen_launch_server, + write_github_step_summary, ) MODEL_SCORE_THRESHOLDS = { - "meta-llama/Llama-3.1-8B-Instruct": 0.83, + "meta-llama/Llama-3.1-8B-Instruct": 0.82, "mistralai/Mistral-7B-Instruct-v0.3": 0.58, - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.84, + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.85, "google/gemma-2-27b-it": 0.92, - "meta-llama/Llama-3.1-70B-Instruct": 0.96, - "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.64, - "Qwen/Qwen2-57B-A14B-Instruct": 0.87, - "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.84, + "meta-llama/Llama-3.1-70B-Instruct": 0.95, + "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.63, + "Qwen/Qwen2-57B-A14B-Instruct": 0.86, + "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.83, "neuralmagic/Mistral-7B-Instruct-v0.3-FP8": 0.54, - "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.83, + "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.84, "neuralmagic/gemma-2-2b-it-FP8": 0.60, - "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8": 0.95, - "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.61, - "neuralmagic/Qwen2-72B-Instruct-FP8": 0.95, + "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8": 0.94, + "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.62, + "neuralmagic/Qwen2-72B-Instruct-FP8": 0.94, "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.82, "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.84, - "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.84, + "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.83, } @@ -44,7 +45,7 @@ def parse_models(model_string): return [model.strip() for model in model_string.split(",") if model.strip()] -def launch_server(base_url, model, is_fp8, is_tp2): +def popen_launch_server_wrapper(base_url, model, is_fp8, is_tp2): other_args = ["--log-level-http", "warning", "--trust-remote-code"] if is_fp8: if "Llama-3" in model or "gemma-2" in model: @@ -67,7 +68,6 @@ def launch_server(base_url, model, is_fp8, is_tp2): base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=other_args, - return_stdout_stderr=(subprocess.DEVNULL, subprocess.DEVNULL), ) return process @@ -99,6 +99,9 @@ def write_results_to_json(model, metrics, mode="a"): def check_model_scores(results): failed_models = [] + summary = " | model | score | threshold |\n" + summary += "| ----- | ----- | --------- |\n" + for model, score in results: threshold = MODEL_SCORE_THRESHOLDS.get(model) if threshold is None: @@ -111,11 +114,19 @@ def check_model_scores(results): f"Model {model} score ({score:.4f}) is below threshold ({threshold:.4f})" ) + line = f"| {model} | {score} | {threshold} |\n" + summary += line + + print(summary) + + if is_in_ci(): + write_github_step_summary(f"### TestNightlyGsm8KEval\n{summary}") + if failed_models: raise AssertionError("\n".join(failed_models)) -class TestEvalAccuracyLarge(unittest.TestCase): +class TestNightlyGsm8KEval(unittest.TestCase): @classmethod def setUpClass(cls): cls.model_groups = [ @@ -127,13 
+138,6 @@ def setUpClass(cls): ] cls.base_url = DEFAULT_URL_FOR_TEST - def setUp(self): - self.process = None - - def tearDown(self): - if self.process: - kill_process_tree(self.process.pid) - def test_mgsm_en_all_models(self): warnings.filterwarnings( "ignore", category=ResourceWarning, message="unclosed.*socket" @@ -144,7 +148,9 @@ def test_mgsm_en_all_models(self): for model_group, is_fp8, is_tp2 in self.model_groups: for model in model_group: with self.subTest(model=model): - self.process = launch_server(self.base_url, model, is_fp8, is_tp2) + process = popen_launch_server_wrapper( + self.base_url, model, is_fp8, is_tp2 + ) args = SimpleNamespace( base_url=self.base_url, @@ -163,8 +169,7 @@ def test_mgsm_en_all_models(self): is_first = False all_results.append((model, metrics["score"])) - - self.tearDown() + kill_process_tree(process.pid) try: with open("results.json", "r") as f: diff --git a/test/srt/test_nightly_human_eval.py b/test/srt/test_nightly_human_eval.py index 626e6fb153f..6558b9effb9 100644 --- a/test/srt/test_nightly_human_eval.py +++ b/test/srt/test_nightly_human_eval.py @@ -4,7 +4,7 @@ import subprocess import unittest -from test_nightly_gsm8k_eval import launch_server, parse_models +from test_nightly_gsm8k_eval import parse_models, popen_launch_server_wrapper from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( @@ -12,19 +12,28 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2, DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1, DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2, + DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_URL_FOR_TEST, + is_in_ci, ) -class TestEvalAccuracyLarge(unittest.TestCase): +class TestNightlyHumanEval(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model_groups = [ - (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1), False, False), - (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2), False, True), - (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1), True, False), - (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2), True, True), - ] + if is_in_ci(): + cls.model_groups = [([DEFAULT_MODEL_NAME_FOR_TEST], False, False)] + else: + cls.model_groups = [ + (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1), False, False), + (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2), False, True), + ( + parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1), + True, + False, + ), + (parse_models(DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2), True, True), + ] cls.base_url = DEFAULT_URL_FOR_TEST cls.process = None cls.eval_process = None @@ -84,7 +93,7 @@ def test_human_eval_all_models(self): # NOTE: only Llama for now if "Llama" in model: with self.subTest(model=model): - self.process = launch_server( + self.process = popen_launch_server_wrapper( self.base_url, model, is_fp8, is_tp2 ) self.run_evalplus(model) diff --git a/test/srt/test_nightly_math_eval.py b/test/srt/test_nightly_math_eval.py new file mode 100644 index 00000000000..3a4eb0adfe4 --- /dev/null +++ b/test/srt/test_nightly_math_eval.py @@ -0,0 +1,46 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestEvalAccuracyLarge(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( 
+ cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--log-level-http", "warning"], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_math(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="math", + num_examples=5000, + num_threads=1024, + ) + + metrics = run_eval(args) + self.assertGreaterEqual( + metrics["score"], 0.519 - 0.02 + ) # -2% to account for sampling variance + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index d007bed31ef..23e0287292b 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -5,6 +5,7 @@ """ import json +import re import time import unittest @@ -13,6 +14,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( + DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -535,5 +537,132 @@ def test_response_prefill(self): ) +# ------------------------------------------------------------------------- +# EBNF Test Class: TestOpenAIServerEBNF +# Launches the server with xgrammar, has only EBNF tests +# ------------------------------------------------------------------------- +class TestOpenAIServerEBNF(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + + # passing xgrammar specifically + other_args = ["--grammar-backend", "xgrammar"] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=other_args, + ) + cls.base_url += "/v1" + cls.tokenizer = get_tokenizer(DEFAULT_SMALL_MODEL_NAME_FOR_TEST) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_ebnf(self): + """ + Ensure we can pass `ebnf` to the local openai server + and that it enforces the grammar. + """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + ebnf_grammar = r""" + root ::= "Hello" | "Hi" | "Hey" + """ + pattern = re.compile(r"^(Hello|Hi|Hey)[.!?]*\s*$") + + response = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "You are a helpful EBNF test bot."}, + {"role": "user", "content": "Say a greeting (Hello, Hi, or Hey)."}, + ], + temperature=0, + max_tokens=32, + extra_body={"ebnf": ebnf_grammar}, + ) + text = response.choices[0].message.content.strip() + print("EBNF test output:", repr(text)) + self.assertTrue(len(text) > 0, "Got empty text from EBNF generation") + self.assertRegex(text, pattern, f"Text '{text}' doesn't match EBNF choices") + + def test_ebnf_strict_json(self): + """ + A stricter EBNF that produces exactly {"name":"Alice"} format + with no trailing punctuation or extra fields. 
+ """ + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + ebnf_grammar = r""" + root ::= "{" pair "}" + pair ::= "\"name\"" ":" string + string ::= "\"" [A-Za-z]+ "\"" + """ + pattern = re.compile(r'^\{"name":"[A-Za-z]+"\}$') + + response = client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "EBNF mini-JSON generator."}, + { + "role": "user", + "content": "Generate single key JSON with only letters.", + }, + ], + temperature=0, + max_tokens=64, + extra_body={"ebnf": ebnf_grammar}, + ) + text = response.choices[0].message.content.strip() + print("EBNF strict JSON test output:", repr(text)) + self.assertTrue(len(text) > 0, "Got empty text from EBNF strict JSON test") + self.assertRegex( + text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape" + ) + + +class TestOpenAIEmbedding(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + + # Configure embedding-specific args + other_args = ["--is-embedding", "--enable-metrics"] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=other_args, + ) + cls.base_url += "/v1" + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_embedding_single(self): + """Test single embedding request""" + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response = client.embeddings.create(model=self.model, input="Hello world") + self.assertEqual(len(response.data), 1) + self.assertTrue(len(response.data[0].embedding) > 0) + + def test_embedding_batch(self): + """Test batch embedding request""" + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response = client.embeddings.create( + model=self.model, input=["Hello world", "Test text"] + ) + self.assertEqual(len(response.data), 2) + self.assertTrue(len(response.data[0].embedding) > 0) + self.assertTrue(len(response.data[1].embedding) > 0) + + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_regex_constrained.py b/test/srt/test_regex_constrained.py new file mode 100644 index 00000000000..6d5acec15e2 --- /dev/null +++ b/test/srt/test_regex_constrained.py @@ -0,0 +1,186 @@ +""" +python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_email +python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting +""" + +import json +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +def setup_class(cls, disable_overlap: bool): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + + other_args = [ + "--max-running-requests", + "10", + "--grammar-backend", + "xgrammar", + ] + + if disable_overlap: + other_args += ["--disable-overlap-schedule"] + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + +class TestRegexConstrained(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_class(cls, disable_overlap=False) + cls.check_jump_forward = False + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def run_decode( + 
self, + regex, + prompt, + return_logprob=False, + top_logprobs_num=0, + n=1, + ): + response = requests.post( + self.base_url + "/generate", + json={ + "text": prompt, + "sampling_params": { + "temperature": 0 if n == 1 else 0.5, + "max_new_tokens": 128, + "n": n, + "regex": regex, + }, + "stream": False, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "logprob_start_len": 0, + }, + ) + + ret = response.json() + print(json.dumps(ret, indent=2)) + print("=" * 100) + + if not isinstance(ret, list): + self.fail(f"Expected response to be a list, but got {type(ret)}") + + for item in ret: + text = item.get("text", "").strip() + if not text: + self.fail("Generated text is empty.") + + if not self.regex_match(text, regex): + self.fail(f"Text '{text}' does not match regex pattern.") + + def regex_match(self, text, pattern): + import re + + return re.match(pattern, text) is not None + + def test_regex_generate_email(self): + pattern = r"^user@example\.com$" + prompt = "Generate an email address:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_greeting(self): + pattern = r"^(Hello|Hi|Hey)$" + prompt = "Generate a greeting:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_number(self): + pattern = r"^\d{3}$" + prompt = "Generate a three-digit number:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_phone(self): + pattern = r"^\(\d{3}\) \d{3}-\d{4}$" + prompt = "Generate a phone number:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_date(self): + pattern = r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$" + prompt = "Generate a date in YYYY-MM-DD format:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_hex_color(self): + pattern = r"^#[0-9A-F]{6}$" + prompt = "Generate a hex color code:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_complex_json(self): + pattern = r'^\{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*\}$' + prompt = "Generate a simple JSON with name, age, and city:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_custom_log_format(self): + pattern = r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$" + prompt = "Generate a log entry:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + +class TestJumpForward(TestRegexConstrained): + @classmethod + def setUpClass(cls): + setup_class(cls, disable_overlap=True) + cls.check_jump_forward = True + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_release_memory_occupation.py b/test/srt/test_release_memory_occupation.py new file mode 100644 index 00000000000..c84b64e77df --- /dev/null +++ b/test/srt/test_release_memory_occupation.py @@ -0,0 +1,98 @@ +import time +import unittest + +import torch +from transformers import AutoModelForCausalLM + +import sglang as sgl +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST + +# (temporarily) set to true to observe memory usage in nvidia-smi more clearly +_DEBUG_EXTRA = True + + +class TestReleaseMemoryOccupation(unittest.TestCase): + def test_release_and_resume_occupation(self): + prompt = "Today is a sunny day and I like" + sampling_params = {"temperature": 0, "max_new_tokens": 8} + model_name = 
DEFAULT_SMALL_MODEL_NAME_FOR_TEST + expect_output = " to spend it outdoors. I decided to" + + engine = sgl.Engine( + model_path=model_name, + random_seed=42, + enable_memory_saver=True, + # disable_cuda_graph=True, # for debugging only + ) + hf_model_new = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype="bfloat16" + ) + + print("generate (#1)") + outputs = engine.generate(prompt, sampling_params)["text"] + self.assertEqual(outputs, expect_output) + + if _DEBUG_EXTRA: + time.sleep(3) + + self.assertEqual( + _try_allocate_big_tensor(), + False, + "Should not be able to allocate big tensors before releasing", + ) + + print("release_memory_occupation start") + t = time.time() + engine.release_memory_occupation() + if _DEBUG_EXTRA: + print("release_memory_occupation", time.time() - t) + + if _DEBUG_EXTRA: + time.sleep(5) + + self.assertEqual( + _try_allocate_big_tensor(), + True, + "Should be able to allocate big tensors aftre releasing", + ) + + if _DEBUG_EXTRA: + time.sleep(5) + + print("resume_memory_occupation start") + t = time.time() + engine.resume_memory_occupation() + if _DEBUG_EXTRA: + print("resume_memory_occupation", time.time() - t) + + self.assertEqual( + _try_allocate_big_tensor(), + False, + "Should not be able to allocate big tensors after resuming", + ) + + print("update_weights_from_tensor") + # As if: PPO has updated hf model's weights, and now we sync it to SGLang + engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) + + print("generate (#2)") + outputs = engine.generate(prompt, sampling_params)["text"] + self.assertEqual(outputs, expect_output) + + if _DEBUG_EXTRA: + time.sleep(4) + + engine.shutdown() + + +def _try_allocate_big_tensor(size: int = 20_000_000_000): + try: + torch.empty((size,), dtype=torch.uint8, device="cuda") + torch.cuda.empty_cache() + return True + except torch.cuda.OutOfMemoryError: + return False + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_request_length_validation.py b/test/srt/test_request_length_validation.py new file mode 100644 index 00000000000..713e3e21e56 --- /dev/null +++ b/test/srt/test_request_length_validation.py @@ -0,0 +1,71 @@ +import unittest + +import openai + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestRequestLengthValidation(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + + # Start server with auto truncate disabled + cls.process = popen_launch_server( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=("--max-total-tokens", "1000", "--context-length", "100"), + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_input_length_validation(self): + client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1") + + long_text = "hello " * 100 # Will tokenize to more than context length + + with self.assertRaises(openai.BadRequestError) as cm: + client.chat.completions.create( + model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + messages=[ + {"role": "user", "content": long_text}, + ], + temperature=0, + ) + + self.assertIn("is longer than the model's context length", str(cm.exception)) + + def test_max_tokens_validation(self): + client = openai.Client(api_key=self.api_key, 
base_url=f"{self.base_url}/v1") + + long_text = "hello " + + with self.assertRaises(openai.BadRequestError) as cm: + client.chat.completions.create( + model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + messages=[ + {"role": "user", "content": long_text}, + ], + temperature=0, + max_tokens=500, + ) + + self.assertIn( + "Requested token count exceeds the model's maximum context", + str(cm.exception), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_schedule_policy.py b/test/srt/test_schedule_policy.py new file mode 100644 index 00000000000..52c5b828984 --- /dev/null +++ b/test/srt/test_schedule_policy.py @@ -0,0 +1,52 @@ +import unittest + +from sglang.srt.managers.schedule_batch import Req +from sglang.srt.managers.schedule_policy import ( + CacheAgnosticPolicy, + CacheAwarePolicy, + SchedulePolicy, +) +from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode +from sglang.srt.sampling.sampling_params import SamplingParams + + +class TestSchedulePolicy(unittest.TestCase): + + def setUp(self): + self.tree_cache = RadixCache(None, None, False) + + def test_init_with_cache_aware_policy(self): + policy = SchedulePolicy(policy="lpm", tree_cache=self.tree_cache) + self.assertEqual(policy.policy, CacheAwarePolicy.LPM) + + def test_init_with_cache_agnostic_policy(self): + policy = SchedulePolicy(policy="fcfs", tree_cache=self.tree_cache) + self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS) + + def test_init_with_unknown_policy(self): + with self.assertRaises(ValueError): + SchedulePolicy(policy="invalid", tree_cache=self.tree_cache) + + def test_init_with_disabled_cache(self): + disabled_tree_cache = RadixCache(None, None, disable=True) + policy = SchedulePolicy(policy="lpm", tree_cache=disabled_tree_cache) + self.assertEqual(policy.policy, CacheAgnosticPolicy.FCFS) + + def test_calc_priority_fcfs(self): + tree_cache = RadixCache(None, None, False) + waiting_queue = [ + Req(1, "a b", [1, 2], SamplingParams()), + Req(3, "a b c", [1, 2, 3], SamplingParams()), + Req(2, "a", [1], SamplingParams()), + ] + + policy = SchedulePolicy(policy="fcfs", tree_cache=tree_cache) + policy.calc_priority(waiting_queue) + # Check if FCFS keeps the original order + self.assertEqual(waiting_queue[0].rid, 1) + self.assertEqual(waiting_queue[1].rid, 3) + self.assertEqual(waiting_queue[2].rid, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_session_control.py b/test/srt/test_session_control.py index 47169aeaa36..2915133f437 100644 --- a/test/srt/test_session_control.py +++ b/test/srt/test_session_control.py @@ -1,11 +1,16 @@ """ Usage: python3 -m unittest test_session_control.TestSessionControl.test_session_control +python3 -m unittest test_session_control.TestSessionControl.test_session_control_with_branching +python3 -m unittest test_session_control.TestSessionControl.test_session_control_backtrack_with_abort python3 -m unittest test_session_control.TestSessionControlVision.test_session_control """ +import asyncio +import json import unittest +import aiohttp import requests from sglang.srt.hf_transformers_utils import get_tokenizer @@ -18,6 +23,10 @@ ) +def remove_prefix(text: str, prefix: str) -> str: + return text[len(prefix) :] if text.startswith(prefix) else text + + class TestSessionControl(unittest.TestCase): @classmethod def setUpClass(cls): @@ -31,23 +40,34 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_session_control(self): + def test_session_control(self, gen_len=12): chunks = [ "Let me tell 
you something about France.", "The capital of France is", + "The population of the city is", "A brief history about that city is", - "To plan a travel, the budget is", ] tokenizer = get_tokenizer(self.model) chunks_ids = [tokenizer.encode(x) for x in chunks] + for i in range(1, len(chunks_ids)): + if chunks_ids[i][0] == tokenizer.bos_token_id: + chunks_ids[i] = chunks_ids[i][1:] # 1. using session control + requests.post(self.base_url + "/flush_cache") session_id = requests.post( self.base_url + "/open_session", json={"capacity_of_str_len": 1000}, ).json() rid = None + # open an existing session, should get session_id as None + response = requests.post( + self.base_url + "/open_session", + json={"capacity_of_str_len": 1000, "session_id": session_id}, + ).json() + assert isinstance(response, dict) and "error" in response + first_rid = None outputs_from_session = [] for i, chunk_ids in enumerate(chunks_ids): @@ -55,11 +75,16 @@ def test_session_control(self): self.base_url + "/generate", json={ "input_ids": chunk_ids, - "session": [session_id, rid], + "session_params": { + "id": session_id, + "rid": rid, + "offset": -1, + "replace": True, + }, "sampling_params": { "temperature": 0, "max_new_tokens": ( - 16 if i > 0 else 0 + gen_len if i > 0 else 1 ), # prefill only for the first chunk "no_stop_trim": True, "skip_special_tokens": False, @@ -77,10 +102,15 @@ def test_session_control(self): self.base_url + "/generate", json={ "input_ids": chunks_ids[-1], - "session": [session_id, first_rid], + "session_params": { + "id": session_id, + "rid": first_rid, + "offset": -1, + "replace": True, + }, "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -93,10 +123,15 @@ def test_session_control(self): self.base_url + "/generate", json={ "input_ids": chunks_ids[-1], - "session": [session_id, rid], + "session_params": { + "id": session_id, + "rid": rid, + "offset": -1, + "replace": True, + }, "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -115,10 +150,15 @@ def test_session_control(self): self.base_url + "/generate", json={ "input_ids": chunks_ids[-1], - "session": [session_id, first_rid], + "session_params": { + "id": session_id, + "rid": first_rid, + "offset": -1, + "replace": True, + }, "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -127,6 +167,8 @@ def test_session_control(self): assert response["meta_info"]["finish_reason"]["type"] == "abort" # 2. 
not use session control + requests.post(self.base_url + "/flush_cache") + input_ids_first_req = None input_ids = [] outputs_normal = [] @@ -139,7 +181,7 @@ def test_session_control(self): "sampling_params": { "temperature": 0, "max_new_tokens": ( - 16 if i > 0 else 0 + gen_len if i > 0 else 1 ), # prefill only for the first chunk "no_stop_trim": True, "skip_special_tokens": False, @@ -150,7 +192,7 @@ def test_session_control(self): output_ids = tokenizer.encode(response["text"]) if output_ids[0] == tokenizer.bos_token_id: output_ids = output_ids[1:] - input_ids += output_ids + input_ids += output_ids[:-1] outputs_normal.append(response["text"]) if i == 0: input_ids_first_req = input_ids.copy() @@ -162,7 +204,7 @@ def test_session_control(self): "input_ids": input_ids_first_req, "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -174,7 +216,282 @@ def test_session_control(self): print(outputs_from_session) print("outputs from normal queries:") print(outputs_normal) - assert outputs_from_session == outputs_normal + assert ( + outputs_from_session == outputs_normal + ), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}" + + async def async_generate(self, payload): + url = self.base_url + "/generate" + async with aiohttp.ClientSession() as session: + async with session.post(url=url, json=payload) as response: + assert response.status == 200 + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + if chunk == "[DONE]": + yield "", None, "" + else: + data = json.loads(chunk) + finish_reason = ( + data["meta_info"]["finish_reason"]["type"] + if data["meta_info"]["finish_reason"] + else "" + ) + yield data["text"], data["meta_info"]["id"], finish_reason + + async def run_session_control_backtrack_with_abort(self, replace): + chunks = [ + "Let me tell you something about France.", + "The capital of France is", + ] + tokenizer = get_tokenizer(self.model) + chunks_ids = [tokenizer.encode(x) for x in chunks] + for i in range(1, len(chunks_ids)): + if chunks_ids[i][0] == tokenizer.bos_token_id: + chunks_ids[i] = chunks_ids[i][1:] + + # 1. 
using session control + requests.post(self.base_url + "/flush_cache") + session_id = requests.post( + self.base_url + "/open_session", + json={"capacity_of_str_len": 1000}, + ).json() + rid = None + + payload = { + "input_ids": chunks_ids[0], + "session_params": { + "id": session_id, + "rid": rid, + "offset": -1, + "replace": True, + }, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 100, + "no_stop_trim": True, + "skip_special_tokens": False, + "ignore_eos": True, + }, + "stream": True, + } + gen_so_far = "" + finish_reason = "" + second_output = "" + async for chunk, rid, finish_reason_chunk in self.async_generate(payload): + gen_so_far += chunk + if finish_reason == "": + finish_reason = finish_reason_chunk + if len(gen_so_far) > 50 and second_output == "": + payload2 = { + "input_ids": chunks_ids[1], + "session_params": { + "id": session_id, + "rid": rid, + "offset": 50, + "replace": replace, + }, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + "stream": False, + "stream_output": True, + } + response = requests.post( + url=self.base_url + "/generate", json=payload2 + ).json() + second_output = response["text"] + if replace: + assert finish_reason == "abort" + print("first request output:") + print(gen_so_far) + print("second request output:") + print(second_output) + + # close the session + ret = requests.post( + self.base_url + "/close_session", + json={"session_id": session_id}, + ) + assert ret.status_code == 200 + + if not replace: + assert response["meta_info"]["finish_reason"]["type"] == "abort" + else: + # 2. not using session control + requests.post(self.base_url + "/flush_cache") + output_ids = tokenizer.encode(gen_so_far) + if output_ids[0] == tokenizer.bos_token_id: + output_ids = output_ids[1:] + input_ids = chunks_ids[0] + output_ids + input_ids = input_ids[:50] + chunks_ids[1] + payload = { + "input_ids": input_ids, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + "stream": False, + "stream_output": True, + } + response = requests.post( + url=self.base_url + "/generate", json=payload + ).json() + output_no_session = response["text"] + print("second request output without session:") + print(output_no_session) + assert ( + second_output == output_no_session + ), f"second_output: {second_output}, output_no_session: {output_no_session}" + + def test_session_control_backtrack_with_abort(self): + asyncio.run(self.run_session_control_backtrack_with_abort(replace=True)) + asyncio.run(self.run_session_control_backtrack_with_abort(replace=False)) + + def run_session_control_with_branching( + self, root_prompt, chunks_per_step, gen_len=16 + ): + for x in chunks_per_step: + assert len(x) == len(chunks_per_step[0]) + + # 1. 
using session control + requests.post(self.base_url + "/flush_cache") + session_id = requests.post( + self.base_url + "/open_session", + json={"capacity_of_str_len": 1000}, + ).json() + + outputs_from_session = [] + # send the root prompt + response = requests.post( + self.base_url + "/generate", + json={ + "text": root_prompt, + "session_params": { + "id": session_id, + "rid": None, + "offset": 0, + "replace": False, + }, + "sampling_params": { + "temperature": 0, + "max_new_tokens": gen_len, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + }, + ).json() + rid_per_branch = [response["meta_info"]["id"]] * len(chunks_per_step[0]) + outputs_from_session.append(response["text"]) + + # send the prompts in branches + for chunks_for_branches in chunks_per_step: + for j, chunk in enumerate(chunks_for_branches): + response = requests.post( + self.base_url + "/generate", + json={ + "text": chunk, + "session_params": { + "id": session_id, + "rid": rid_per_branch[j], + "offset": 0, + "replace": False, + }, + "sampling_params": { + "temperature": 0, + "max_new_tokens": gen_len, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + }, + ).json() + rid = response["meta_info"]["id"] + rid_per_branch[j] = rid + outputs_from_session.append(response["text"]) + + # close the session + ret = requests.post( + self.base_url + "/close_session", + json={"session_id": session_id}, + ) + assert ret.status_code == 200 + + # 2. not use session control + requests.post(self.base_url + "/flush_cache") + + outputs_normal = [] + input_texts = [root_prompt] * len(chunks_per_step[0]) + # send the root prompt + response = requests.post( + self.base_url + "/generate", + json={ + "text": root_prompt, + "sampling_params": { + "temperature": 0, + "max_new_tokens": gen_len, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + }, + ).json() + outputs_normal.append(response["text"]) + input_texts = [x + response["text"] for x in input_texts] + + # send the prompts in branches + for chunks_for_branches in chunks_per_step: + for j, chunk in enumerate(chunks_for_branches): + input_texts[j] += chunk + response = requests.post( + self.base_url + "/generate", + json={ + "text": input_texts[j], + "sampling_params": { + "temperature": 0, + "max_new_tokens": gen_len, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + }, + ).json() + outputs_normal.append(response["text"]) + input_texts[j] += response["text"] + + print("====== outputs from chunked queries with session control: =======") + print(outputs_from_session) + print("====== outputs from normal queries: =======") + print(outputs_normal) + assert ( + outputs_from_session == outputs_normal + ), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}" + + def test_session_control_with_branching(self): + root_prompt = "First, let me explain in one sentence about AI" + chunks_per_step = [ + [ + "Then, briefly, the positive side of AI is", + "But, briefly, AI could be harmful to human", + ], + ["For example", "For example"], + ] + self.run_session_control_with_branching( + root_prompt=root_prompt, chunks_per_step=chunks_per_step, gen_len=8 + ) + + root_prompt = "I have three apples." 
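+ # Second scenario: arithmetic-style prompts whose branches diverge after the shared root prompt.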
+ chunks_per_step = [ + ["I then give one apple to my friend", "My friend give me another apple."], + ["I still have", "I now have"], + ] + self.run_session_control_with_branching( + root_prompt=root_prompt, chunks_per_step=chunks_per_step, gen_len=8 + ) class TestSessionControlVision(unittest.TestCase): @@ -197,39 +514,60 @@ def test_session_control(self): text_chunks = [ "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n", "<|im_start|>user\n\nDescribe this image in a very short sentence.<|im_end|>\n<|im_start|>assistant\n", - "<|im_start|>user\n\nIs this image same with the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n", - "<|im_start|>user\n\nIs this image same with the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n", + "<|im_start|>user\n\nIs this image same with one of the previous images?<|im_end|>\n<|im_start|>assistant\n", + "<|im_start|>user\n\nIs this image same with one of the previous images?<|im_end|>\n<|im_start|>assistant\n", + "<|im_start|>user\nDescribe this image in a very short sentence.<|im_end|>\nassistant:", ] image_chunks = [ - "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png", "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png", "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png", + "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png", ] - assert len(text_chunks) == len(image_chunks) + 1 + + assert ( + len(text_chunks) == len(image_chunks) + 2 + ) # the first and the last prompt does not contain images tokenizer = get_tokenizer(self.model) text_input_ids = [tokenizer.encode(x) for x in text_chunks] + for i in range(1, len(text_input_ids)): + if text_input_ids[i][0] == tokenizer.bos_token_id: + text_input_ids[i] = text_input_ids[i][1:] + gen_len = 32 # 1. 
using session control + requests.post(self.base_url + "/flush_cache") session_id = requests.post( self.base_url + "/open_session", json={"capacity_of_str_len": 1000}, ).json() rid = None + # open an existing session, should get session_id as None + response = requests.post( + self.base_url + "/open_session", + json={"capacity_of_str_len": 1000, "session_id": session_id}, + ).json() + assert isinstance(response, dict) and "error" in response + first_rid = None outputs_from_session = [] - for i in range(len(text_input_ids)): + for i in range(len(text_input_ids[:-1])): response = requests.post( self.base_url + "/generate", json={ "input_ids": text_input_ids[i], "image_data": image_chunks[i - 1] if i > 0 else None, "modalities": ["multi-images"], - "session": [session_id, rid], + "session_params": { + "id": session_id, + "rid": rid, + "offset": 0, + "replace": True, + }, "sampling_params": { "temperature": 0, "max_new_tokens": ( - 16 if i > 0 else 0 + gen_len if i > 0 else 0 ), # prefill only for the first chunk "no_stop_trim": True, "skip_special_tokens": False, @@ -247,12 +585,15 @@ def test_session_control(self): self.base_url + "/generate", json={ "input_ids": text_input_ids[-1], - "image_data": image_chunks[-1:], - "modalities": ["multi-images"], - "session": [session_id, first_rid], + "session_params": { + "id": session_id, + "rid": first_rid, + "offset": 0, + "replace": True, + }, "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -265,12 +606,15 @@ def test_session_control(self): self.base_url + "/generate", json={ "input_ids": text_input_ids[-1], - "image_data": image_chunks[-1:], - "modalities": ["multi-images"], - "session": [session_id, rid], + "session_params": { + "id": session_id, + "rid": rid, + "offset": 0, + "replace": True, + }, "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -289,10 +633,15 @@ def test_session_control(self): self.base_url + "/generate", json={ "input_ids": text_input_ids[-1], - "session": [session_id, first_rid], + "session_params": { + "id": session_id, + "rid": first_rid, + "offset": 0, + "replace": True, + }, "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -306,7 +655,7 @@ def test_session_control(self): input_ids_first_req = None input_ids = [] outputs_normal = [] - for i in range(len(text_input_ids)): + for i in range(len(text_input_ids[:-1])): input_ids += text_input_ids[i] image_data = image_chunks[:i] if i > 0 else None response = requests.post( @@ -318,7 +667,7 @@ def test_session_control(self): "sampling_params": { "temperature": 0, "max_new_tokens": ( - 16 if i > 0 else 0 + gen_len if i > 0 else 0 ), # prefill only for the first chunk "no_stop_trim": True, "skip_special_tokens": False, @@ -339,11 +688,9 @@ def test_session_control(self): self.base_url + "/generate", json={ "input_ids": input_ids_first_req, - "image_data": image_chunks[-1:], - "modalities": ["multi-images"], "sampling_params": { "temperature": 0, - "max_new_tokens": 16, + "max_new_tokens": gen_len, "no_stop_trim": True, "skip_special_tokens": False, }, @@ -355,7 +702,9 @@ def test_session_control(self): print(outputs_from_session) print("outputs from normal queries:") print(outputs_normal) - assert outputs_from_session == outputs_normal + assert ( + outputs_from_session == 
outputs_normal + ), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}" if __name__ == "__main__": diff --git a/test/srt/test_skip_tokenizer_init.py b/test/srt/test_skip_tokenizer_init.py index bc99b23ad58..db70944091f 100644 --- a/test/srt/test_skip_tokenizer_init.py +++ b/test/srt/test_skip_tokenizer_init.py @@ -1,11 +1,8 @@ -""" -python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample -""" - import json import unittest import requests +from transformers import AutoTokenizer from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( @@ -15,35 +12,63 @@ popen_launch_server, ) +_server_process = None +_base_url = None +_tokenizer = None + + +def setUpModule(): + """ + Launch the server once before all tests and initialize the tokenizer. + """ + global _server_process, _base_url, _tokenizer + _server_process = popen_launch_server( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_URL_FOR_TEST, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--skip-tokenizer-init"], + ) + _base_url = DEFAULT_URL_FOR_TEST + + _tokenizer = AutoTokenizer.from_pretrained( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False + ) + print(">>> setUpModule: Server launched, tokenizer ready") + + +def tearDownModule(): + """ + Terminate the server once after all tests have completed. + """ + global _server_process + if _server_process is not None: + kill_process_tree(_server_process.pid) + _server_process = None + print(">>> tearDownModule: Server terminated") -class TestSkipTokenizerInit(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--skip-tokenizer-init"], - ) - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) +class TestSkipTokenizerInit(unittest.TestCase): + def run_decode( + self, + prompt_text="The capital of France is", + max_new_tokens=32, + return_logprob=False, + top_logprobs_num=0, + n=1, + ): + input_ids = _tokenizer(prompt_text, return_tensors="pt")["input_ids"][ + 0 + ].tolist() - def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): - max_new_tokens = 32 - input_ids = [128000, 791, 6864, 315, 9822, 374] # The capital of France is response = requests.post( - self.base_url + "/generate", + _base_url + "/generate", json={ "input_ids": input_ids, "sampling_params": { "temperature": 0 if n == 1 else 0.5, "max_new_tokens": max_new_tokens, "n": n, - "stop_token_ids": [119690], + "stop_token_ids": [_tokenizer.eos_token_id], }, "stream": False, "return_logprob": return_logprob, @@ -52,23 +77,37 @@ def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): }, ) ret = response.json() - print(json.dumps(ret)) + print(json.dumps(ret, indent=2)) def assert_one_item(item): - assert len(item["token_ids"]) == item["meta_info"]["completion_tokens"] - assert len(item["token_ids"]) == max_new_tokens - assert item["meta_info"]["prompt_tokens"] == len(input_ids) - - if return_logprob: - assert len(item["meta_info"]["input_token_logprobs"]) == len( - input_ids - ), f'{len(item["meta_info"]["input_token_logprobs"])} vs. 
f{len(input_ids)}' - assert len(item["meta_info"]["output_token_logprobs"]) == max_new_tokens - + if item["meta_info"]["finish_reason"]["type"] == "stop": + self.assertEqual( + item["meta_info"]["finish_reason"]["matched"], + _tokenizer.eos_token_id, + ) + elif item["meta_info"]["finish_reason"]["type"] == "length": + self.assertEqual( + len(item["token_ids"]), item["meta_info"]["completion_tokens"] + ) + self.assertEqual(len(item["token_ids"]), max_new_tokens) + self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids)) + + if return_logprob: + self.assertEqual( + len(item["meta_info"]["input_token_logprobs"]), + len(input_ids), + f'{len(item["meta_info"]["input_token_logprobs"])} mismatch with {len(input_ids)}', + ) + self.assertEqual( + len(item["meta_info"]["output_token_logprobs"]), + max_new_tokens, + ) + + # Determine whether to assert a single item or multiple items based on n if n == 1: assert_one_item(ret) else: - assert len(ret) == n + self.assertEqual(len(ret), n) for i in range(n): assert_one_item(ret[i]) @@ -82,10 +121,10 @@ def test_parallel_sample(self): def test_logprob(self): for top_logprobs_num in [0, 3]: - self.run_decode( - return_logprob=True, - top_logprobs_num=top_logprobs_num, - ) + self.run_decode(return_logprob=True, top_logprobs_num=top_logprobs_num) + + def test_eos_behavior(self): + self.run_decode(max_new_tokens=256) if __name__ == "__main__": diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index aff1d4a78fc..68db1d69983 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -4,11 +4,16 @@ """ import json +import random +import time import unittest +from concurrent.futures import ThreadPoolExecutor +from typing import Optional import numpy as np import requests +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, @@ -24,7 +29,14 @@ def setUpClass(cls): cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=( + "--enable-custom-logit-processor", + "--mem-fraction-static", + "0.8", + ), ) @classmethod @@ -147,14 +159,26 @@ def test_logprob_with_chunked_prefill(self): }, "return_logprob": True, "logprob_start_len": -1, + "top_logprobs_num": 5, }, ) response_json = response.json() - print(json.dumps(response_json, indent=2)) + # print(json.dumps(response_json, indent=2)) res = response_json self.assertEqual(res["meta_info"]["completion_tokens"], new_tokens) + + # Test the number of tokens are correct self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), new_tokens) + self.assertEqual(len(res["meta_info"]["output_top_logprobs"]), new_tokens) + + # Test the top-1 tokens are the same as output tokens (because temp = 0.0) + for i in range(new_tokens): + self.assertListEqual( + res["meta_info"]["output_token_logprobs"][i], + res["meta_info"]["output_top_logprobs"][i][0], + ) + self.assertEqual(len(res["meta_info"]["output_top_logprobs"][i]), 5) def test_logprob_match(self): """Test the output logprobs are close to the input logprobs if we run a prefill again.""" @@ -213,6 +237,232 @@ def run_generate( max_diff = np.max(diff) self.assertLess(max_diff, 0.25) + def run_logprob_check(self, arg): + ( + input_len, + output_len, + temperature, + 
logprob_start_len, + return_logprob, + top_logprobs_num, + ) = arg + input_ids = list(range(input_len)) + + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": input_ids, + "sampling_params": { + "temperature": temperature, + "max_new_tokens": output_len, + }, + "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, + "top_logprobs_num": top_logprobs_num, + }, + ) + response_json = response.json() + + res = response_json + self.assertEqual(res["meta_info"]["prompt_tokens"], input_len) + self.assertEqual(res["meta_info"]["completion_tokens"], output_len) + + # Test the number of tokens are correct + if return_logprob: + # This is because if logprob_start_len == 0, we added a padding for the first token. + # In other cases, we do not add the padding + delta = 0 if logprob_start_len == 0 else 1 + + self.assertEqual( + len(res["meta_info"]["input_token_logprobs"]) + + logprob_start_len + + delta, + res["meta_info"]["prompt_tokens"], + ) + self.assertEqual(len(res["meta_info"]["output_token_logprobs"]), output_len) + + if top_logprobs_num: + self.assertEqual( + len(res["meta_info"]["input_top_logprobs"]) + + logprob_start_len + + delta, + res["meta_info"]["prompt_tokens"], + ) + self.assertEqual( + len(res["meta_info"]["output_top_logprobs"]), output_len + ) + + for i in range(output_len): + self.assertEqual( + len(res["meta_info"]["output_top_logprobs"][i]), + top_logprobs_num, + ) + + # Test the top-1 tokens are the same as output tokens if temperature == 0 + if temperature == 0: + self.assertListEqual( + res["meta_info"]["output_token_logprobs"][i], + res["meta_info"]["output_top_logprobs"][i][0], + ) + + def test_logprob_mixed(self): + args = [] + temperature = 0 + # input_len, output_len, temperature, logprob_start_len, return_logprob, top_logprobs_num + for input_len in [1000, 2000]: + for output_len in [4, 8]: + for logprob_start_len in [0, 500, 1000]: + for return_logprob in [True, False]: + for top_logprobs_num in [0, 5]: + + if logprob_start_len >= input_len: + continue + + args.append( + ( + input_len, + output_len, + temperature, + logprob_start_len, + return_logprob, + top_logprobs_num, + ) + ) + + random.shuffle(args) + + with ThreadPoolExecutor(8) as executor: + list(executor.map(self.run_logprob_check, args)) + + def test_logprob_grammar(self): + prompts = "Question: Is Paris the Capital of France? Answer:" + allowed_tokens = [" Yes", " No"] + + response = requests.post( + self.base_url + "/generate", + json={ + "text": prompts, + "sampling_params": { + "temperature": 1.0, + "max_new_tokens": 1, + "regex": "( Yes| No)", + }, + "return_logprob": True, + "top_logprobs_num": 5, # The grammar constraint allows all prefix tokens so we need to use a larger top_k. + "return_text_in_logprobs": True, + }, + ) + response_json = response.json() + output_top_logprobs = response_json["meta_info"]["output_top_logprobs"][0] + print(f"{output_top_logprobs=}") + + # Parse results + # This is becaues the grammar constraint allows all prefix tokens + logprobs = [None] * 2 + for i in range(len(output_top_logprobs)): + try: + idx = allowed_tokens.index(output_top_logprobs[i][2]) + except ValueError: + # Not found + continue + logprobs[idx] = output_top_logprobs[i][0] + + self.assertTrue(all(x is not None for x in logprobs)) + + def run_custom_logit_processor(self, target_token_id: Optional[int] = None): + """Test custom logit processor with custom params. + + If target_token_id is None, the custom logit processor won't be passed in. 
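+ The server in setUpClass is launched with --enable-custom-logit-processor, so the serialized processor can be sent in the /generate payload.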
+ """ + + custom_params = {"token_id": target_token_id} + + class DeterministicLogitProcessor(CustomLogitProcessor): + """A dummy logit processor that changes the logits to always + sample the given token id. + """ + + def __call__(self, logits, custom_param_list): + assert logits.shape[0] == len(custom_param_list) + key = "token_id" + + for i, param_dict in enumerate(custom_param_list): + # Mask all other tokens + logits[i, :] = -float("inf") + # Assign highest probability to the specified token + logits[i, param_dict[key]] = 0.0 + return logits + + prompts = "Question: Is Paris the Capital of France? Answer:" + + # Base case json data to be posted to the server. + base_json = { + "text": prompts, + "sampling_params": {"temperature": 0.0}, + "return_logprob": True, + } + + # Custom json data with custom logit processor and params. + custom_json = base_json.copy() + # Only set the custom logit processor if target_token_id is not None. + if target_token_id is not None: + custom_json["custom_logit_processor"] = ( + DeterministicLogitProcessor().to_str() + ) + custom_json["sampling_params"]["custom_params"] = custom_params + + custom_response = requests.post( + self.base_url + "/generate", + json=custom_json, + ).json() + + output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"] + sampled_tokens = [x[1] for x in output_token_logprobs] + + # The logit processor should always sample the given token as the logits is deterministic. + if target_token_id is not None: + self.assertTrue( + all(x == custom_params["token_id"] for x in sampled_tokens), + # Print the detailed test case info if the test fails. + f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}", + ) + + def test_custom_logit_processor(self): + """Test custom logit processor with a single request.""" + self.run_custom_logit_processor(target_token_id=5) + + def test_custom_logit_processor_batch_mixed(self): + """Test a batch of requests mixed of requests with and without custom logit processor.""" + target_token_ids = list(range(32)) + [None] * 16 + random.shuffle(target_token_ids) + with ThreadPoolExecutor(len(target_token_ids)) as executor: + list(executor.map(self.run_custom_logit_processor, target_token_ids)) + + def test_cache_tokens(self): + for _ in range(2): + time.sleep(1) + response = requests.post(self.base_url + "/flush_cache") + assert response.status_code == 200 + + def send_and_check_cached_tokens(input_ids): + response = requests.post( + self.base_url + "/generate", + json={ + "input_ids": list(input_ids), + "sampling_params": { + "max_new_tokens": 1, + }, + }, + ) + response_json = response.json() + return response_json["meta_info"]["cached_tokens"] + + self.assertEqual(send_and_check_cached_tokens(range(0, 100)), 0) + self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 100) + self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 9999) + self.assertEqual(send_and_check_cached_tokens(range(0, 1000)), 999) + self.assertEqual(send_and_check_cached_tokens(range(0, 11000)), 10000) + def test_get_server_info(self): response = requests.get(self.base_url + "/get_server_info") response_json = response.json() diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py index 7479b646837..c535d5c0686 100644 --- a/test/srt/test_srt_engine.py +++ b/test/srt/test_srt_engine.py @@ -1,6 +1,6 @@ """ Usage: -python3 -m unittest test_srt_engine.TestSRTEngine.test_3_sync_streaming_combination +python3 -m unittest test_srt_engine.TestSRTEngine.test_4_sync_async_stream_combination 
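+python3 -m unittest test_srt_engine.TestSRTEngine.test_5_gsm8k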
""" import asyncio @@ -44,64 +44,97 @@ def test_1_engine_runtime_consistency(self): print(out2) self.assertEqual(out1, out2) - def test_2_engine_multiple_generate(self): + def test_2_engine_runtime_encode_consistency(self): + prompt = "Today is a sunny day and I like" + model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST + + engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42) + out1 = torch.tensor(engine.encode(prompt)["embedding"]) + engine.shutdown() + + runtime = sgl.Runtime(model_path=model_path, is_embedding=True, random_seed=42) + out2 = torch.tensor(json.loads(runtime.encode(prompt))["embedding"]) + runtime.shutdown() + + self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3)) + + def test_3_engine_token_ids_consistency(self): # just to ensure there is no issue running multiple generate calls prompt = "Today is a sunny day and I like" model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - sampling_params = {"temperature": 0, "max_new_tokens": 8} - engine = sgl.Engine(model_path=model_path, random_seed=42) - engine.generate(prompt, sampling_params) - engine.generate(prompt, sampling_params) - engine.shutdown() + engine = sgl.Engine( + model_path=model_path, random_seed=42, disable_radix_cache=True + ) + out1 = engine.generate(prompt, sampling_params)["text"] - def test_3_sync_streaming_combination(self): + tokenizer = get_tokenizer(model_path) + token_ids = tokenizer.encode(prompt) + out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)[ + "text" + ] - prompt = "AI safety is..." - sampling_params = {"temperature": 0.8, "top_p": 0.95} + engine.shutdown() - async def async_streaming(engine): + print("==== Answer 1 ====") + print(out1) - generator = await engine.async_generate( - prompt, sampling_params, stream=True - ) + print("==== Answer 2 ====") + print(out2) + self.assertEqual(out1, out2) - async for output in generator: - print(output["text"], end="", flush=True) - print() + def test_4_sync_async_stream_combination(self): + prompt = "AI safety is" + sampling_params = {"temperature": 0.8, "top_p": 0.95} # Create an LLM. llm = sgl.Engine( model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ) - # 1. sync + non streaming - print("\n\n==== 1. sync + non streaming ====") - output = llm.generate(prompt, sampling_params) + if True: + # 1. sync + non streaming + print("\n\n==== 1. sync + non streaming ====") + output = llm.generate(prompt, sampling_params) + print(output["text"]) + + # 2. sync + streaming + print("\n\n==== 2. sync + streaming ====") + output_generator = llm.generate(prompt, sampling_params, stream=True) + offset = 0 + for output in output_generator: + print(output["text"][offset:], end="", flush=True) + offset = len(output["text"]) + print() - print(output["text"]) + if True: + loop = asyncio.get_event_loop() + # 3. async + non_streaming + print("\n\n==== 3. async + non streaming ====") + output = loop.run_until_complete( + llm.async_generate(prompt, sampling_params) + ) + print(output["text"]) - # 2. sync + streaming - print("\n\n==== 2. sync + streaming ====") - output_generator = llm.generate(prompt, sampling_params, stream=True) - for output in output_generator: - print(output["text"], end="", flush=True) - print() + # 4. async + streaming + async def async_streaming(engine): + generator = await engine.async_generate( + prompt, sampling_params, stream=True + ) - loop = asyncio.get_event_loop() - # 3. async + non_streaming - print("\n\n==== 3. 
async + non streaming ====") - output = loop.run_until_complete(llm.async_generate(prompt, sampling_params)) - print(output["text"]) + offset = 0 + async for output in generator: + print(output["text"][offset:], end="", flush=True) + offset = len(output["text"]) + print() - # 4. async + streaming - print("\n\n==== 4. async + streaming ====") - loop.run_until_complete(async_streaming(llm)) + print("\n\n==== 4. async + streaming ====") + loop.run_until_complete(async_streaming(llm)) llm.shutdown() - def test_4_gsm8k(self): + def test_5_gsm8k(self): args = SimpleNamespace( model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, @@ -113,46 +146,7 @@ def test_4_gsm8k(self): metrics = run_eval(args) self.assertGreater(metrics["accuracy"], 0.3) - def test_5_prompt_input_ids_consistency(self): - prompt = "The capital of UK is" - - model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - engine = sgl.Engine( - model_path=model_path, random_seed=42, disable_radix_cache=True - ) - sampling_params = {"temperature": 0, "max_new_tokens": 8} - out1 = engine.generate(prompt, sampling_params)["text"] - - tokenizer = get_tokenizer(model_path) - token_ids = tokenizer.encode(prompt) - out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)[ - "text" - ] - - engine.shutdown() - - print("==== Answer 1 ====") - print(out1) - - print("==== Answer 2 ====") - print(out2) - self.assertEqual(out1, out2) - - def test_6_engine_runtime_encode_consistency(self): - prompt = "Today is a sunny day and I like" - model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST - - engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42) - out1 = torch.tensor(engine.encode(prompt)["embedding"]) - engine.shutdown() - - runtime = sgl.Runtime(model_path=model_path, is_embedding=True, random_seed=42) - out2 = torch.tensor(json.loads(runtime.encode(prompt))["embedding"]) - runtime.shutdown() - - self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3)) - - def test_7_engine_cpu_offload(self): + def test_6_engine_cpu_offload(self): prompt = "Today is a sunny day and I like" model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST @@ -182,7 +176,7 @@ def test_7_engine_cpu_offload(self): print(out2) self.assertEqual(out1, out2) - def test_8_engine_offline_throughput(self): + def test_7_engine_offline_throughput(self): server_args = ServerArgs( model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ) diff --git a/test/srt/test_srt_engine_with_quant_args.py b/test/srt/test_srt_engine_with_quant_args.py new file mode 100644 index 00000000000..3851ab41af1 --- /dev/null +++ b/test/srt/test_srt_engine_with_quant_args.py @@ -0,0 +1,60 @@ +import unittest + +import sglang as sgl +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST + + +class TestSRTEngineWithQuantArgs(unittest.TestCase): + + def test_1_quantization_args(self): + + # we only test fp8 because other methods are currenly depend on vllm. We can add other methods back to test after vllm depency is resolved. 
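+ # Each engine below only checks that generation completes without errors for the given quantization setting.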
+ quantization_args_list = [ + # "awq", + "fp8", + # "gptq", + # "marlin", + # "gptq_marlin", + # "awq_marlin", + # "bitsandbytes", + # "gguf", + ] + + prompt = "Today is a sunny day and I like" + model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + + sampling_params = {"temperature": 0, "max_new_tokens": 8} + + for quantization_args in quantization_args_list: + engine = sgl.Engine( + model_path=model_path, random_seed=42, quantization=quantization_args + ) + engine.generate(prompt, sampling_params) + engine.shutdown() + + def test_2_torchao_args(self): + + # we don't test int8dq because currently there is conflict between int8dq and capture cuda graph + torchao_args_list = [ + # "int8dq", + "int8wo", + "fp8wo", + "fp8dq-per_tensor", + "fp8dq-per_row", + ] + [f"int4wo-{group_size}" for group_size in [32, 64, 128, 256]] + + prompt = "Today is a sunny day and I like" + model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + + sampling_params = {"temperature": 0, "max_new_tokens": 8} + + for torchao_config in torchao_args_list: + engine = sgl.Engine( + model_path=model_path, random_seed=42, torchao_config=torchao_config + ) + engine.generate(prompt, sampling_params) + engine.shutdown() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py index 6f3b344b3cc..e71de339117 100644 --- a/test/srt/test_torch_compile.py +++ b/test/srt/test_torch_compile.py @@ -23,7 +23,7 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--enable-torch-compile"], + other_args=["--enable-torch-compile", "--cuda-graph-max-bs", "4"], ) @classmethod diff --git a/test/srt/test_triton_attention_backend.py b/test/srt/test_triton_attention_backend.py index 905590965d5..88904c55fdf 100644 --- a/test/srt/test_triton_attention_backend.py +++ b/test/srt/test_triton_attention_backend.py @@ -30,7 +30,7 @@ def test_latency(self): ) if is_in_ci(): - assert output_throughput > 153, f"{output_throughput=}" + self.assertGreater(output_throughput, 153) def test_mmlu(self): model = DEFAULT_MODEL_NAME_FOR_TEST diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index 44abfd61bd7..2398af9b0a7 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -182,6 +182,7 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): seq_len = 10 # This represents the number of tokens already in the sequence total_tokens = B * seq_len sm_scale = 1.0 / (D**0.5) + num_kv_splits = 8 # q represents the new token being generated, one per batch q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") @@ -195,12 +196,11 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) b_req_idx = torch.arange(B, device="cuda") - b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda") attn_logits = torch.empty( - (H_Q, total_tokens), - dtype=dtype, + (B, H_Q, num_kv_splits, D + 1), + dtype=torch.float32, device="cuda", ) @@ -211,10 +211,9 @@ def _test_decode_attention_once(self, B, H_Q, H_KV, D): o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - seq_len, + num_kv_splits, sm_scale, ) @@ -233,11 +232,12 @@ def test_decode_attention(self): for B, H_Q, H_KV, D in configs: self._test_decode_attention_once(B, H_Q, H_KV, D) - def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): + def 
_test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): dtype = torch.bfloat16 - seq_len = 10 # This represents the number of tokens already in the sequence + seq_len = S # This represents the number of tokens already in the sequence total_tokens = B * seq_len sm_scale = 1.0 / (D**0.5) + num_kv_splits = 8 # q represents the new token being generated, one per batch q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") @@ -247,17 +247,16 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda") # o will have the same shape as q - o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") - o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") + o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") + o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) b_req_idx = torch.arange(B, device="cuda") - b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda") attn_logits = torch.empty( - (H_Q, total_tokens), - dtype=dtype, + (B, H_Q, num_kv_splits, D_V + 1), + dtype=torch.float32, device="cuda", ) @@ -268,13 +267,18 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, - seq_len, + num_kv_splits, sm_scale, ) + attn_logits1 = torch.empty( + (B, H_Q, num_kv_splits, D_V + 1), + dtype=torch.float32, + device="cuda", + ) + decode_attention_fwd_grouped( q, k_buffer, @@ -282,21 +286,23 @@ def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): o_grouped, req_to_token, b_req_idx, - b_start_loc, b_seq_len, - attn_logits, - seq_len, + attn_logits1, + num_kv_splits, sm_scale, ) cos_sim = torch.nn.functional.cosine_similarity( o.flatten(), o_grouped.flatten(), dim=0 ) + print(cos_sim.item()) self.assertTrue(cos_sim.item() > 0.99) self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2)) def test_grouped_decode_attention(self): + seq_lens = [5, 100, 128, 500] configs = [ + (2, 16, 16, 64, 64), (2, 16, 1, 64, 64), (2, 64, 1, 13, 13), (2, 128, 1, 80, 80), @@ -304,8 +310,9 @@ def test_grouped_decode_attention(self): (2, 128, 1, 576, 512), ] - for B, H_Q, H_KV, D, D_V in configs: - self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V) + for S in seq_lens: + for B, H_Q, H_KV, D, D_V in configs: + self._test_grouped_decode_attention_once(B, S, H_Q, H_KV, D, D_V) if __name__ == "__main__": diff --git a/test/srt/test_update_weights_from_tensor.py b/test/srt/test_update_weights_from_tensor.py new file mode 100644 index 00000000000..f38f76c5deb --- /dev/null +++ b/test/srt/test_update_weights_from_tensor.py @@ -0,0 +1,38 @@ +import time +import unittest + +import torch + +import sglang as sgl +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST + + +class TestUpdateWeightsFromTensor(unittest.TestCase): + def test_update_weights_from_tensor(self): + engine = sgl.Engine(model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST) + + param_names = [f"model.layers.{i}.mlp.up_proj.weight" for i in range(6, 16)] + + _check_param(engine, param_names[0], [0.0087, -0.0214, -0.0004, 0.0039, 0.0110]) + + new_tensor = torch.full((16384, 2048), 1.5) + + time_start = time.time() + engine.update_weights_from_tensor([(x, new_tensor) for x in param_names]) + print(f"Time delta: {time.time() - time_start:.03f}") + + for param_name in param_names[:3]: + _check_param(engine, 
param_name, [1.5] * 5) + + engine.shutdown() + + +def _check_param(engine, param_name, expect_values): + actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5] + assert torch.allclose( + actual_values, torch.tensor(expect_values), atol=0.002 + ), f"{actual_values=}" + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_vision_chunked_prefill.py b/test/srt/test_vision_chunked_prefill.py new file mode 100644 index 00000000000..f7725f17bee --- /dev/null +++ b/test/srt/test_vision_chunked_prefill.py @@ -0,0 +1,173 @@ +""" +Usage: +python3 -m unittest test_vision_chunked_prefill.TestVisionChunkedPrefill.test_chunked_prefill +""" + +import base64 +import io +import os +import unittest +from concurrent.futures import ThreadPoolExecutor +from typing import Union + +import numpy as np +import requests +from decord import VideoReader, cpu +from PIL import Image + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestVisionChunkedPrefill(unittest.TestCase): + def prepare_video_messages(self, video_path, max_frames_num=8): + vr = VideoReader(video_path, ctx=cpu(0)) + total_frame_num = len(vr) + uniform_sampled_frames = np.linspace( + 0, total_frame_num - 1, max_frames_num, dtype=int + ) + frame_idx = uniform_sampled_frames.tolist() + frames = vr.get_batch(frame_idx).asnumpy() + + base64_frames = [] + for frame in frames: + pil_img = Image.fromarray(frame) + buff = io.BytesIO() + pil_img.save(buff, format="JPEG") + base64_str = base64.b64encode(buff.getvalue()).decode("utf-8") + base64_frames.append(base64_str) + + messages = [{"role": "user", "content": []}] + frame_format = { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,{}"}, + "modalities": "video", + } + + for base64_frame in base64_frames: + frame_format["image_url"]["url"] = "data:image/jpeg;base64,{}".format( + base64_frame + ) + messages[0]["content"].append(frame_format.copy()) + + prompt = {"type": "text", "text": "Please describe the video briefly."} + messages[0]["content"].append(prompt) + + return messages + + def get_prompt_from_messages(self, messages): + text = ( + "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" + "<|im_start|>user\n" + ) + image_data = [] + for content in messages[0]["content"]: + if content["type"] == "image_url": + text += "\n" + image_data.append(content["image_url"]["url"]) + text += "Please describe the video briefly.<|im_end|>\n<|im_start|>assistant\n" + return text, image_data + + def generate(self, text, image_data): + response = requests.post( + self.base_url + "/generate", + json={ + "text": text, + "image_data": image_data, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 32, + "no_stop_trim": True, + "skip_special_tokens": False, + }, + "modalities": ["multi-images"], + }, + ).json() + return response["text"] + + def generate_for_video(self, batch, num_frame) -> Union[str, list[str]]: + # prepare the video input about Steven introducing ipod nano + url = "https://raw.githubusercontent.com/evolvinglmms-lab/sglang/dev/onevision_local/assets/jobs.mp4" + cache_dir = os.path.expanduser("~/.cache") + file_path = os.path.join(cache_dir, "jobs.mp4") + os.makedirs(cache_dir, exist_ok=True) + if not os.path.exists(file_path): + response = requests.get(url) + response.raise_for_status() + with open(file_path, "wb") as f: + f.write(response.content) + + if not batch: + assert 
+            messages = self.prepare_video_messages(file_path, max_frames_num=num_frame)
+            text, image_data = self.get_prompt_from_messages(messages)
+            return self.generate(text, image_data)
+        else:
+            assert isinstance(num_frame, list)
+            func_args = []
+            for max_frames_num in num_frame:
+                messages = self.prepare_video_messages(
+                    file_path,
+                    max_frames_num=max_frames_num,
+                )
+                text, image_data = self.get_prompt_from_messages(messages)
+                func_args.append((text, image_data))
+
+            with ThreadPoolExecutor(max_workers=10) as executor:
+                responses = list(executor.map(lambda p: self.generate(*p), func_args))
+
+            return responses
+
+    def run_generate(self, chunked_prefill_size, batch, num_frame):
+        # launch server
+        model = "lmms-lab/llava-onevision-qwen2-7b-ov"
+        # model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
+        self.base_url = DEFAULT_URL_FOR_TEST
+        process = popen_launch_server(
+            model,
+            self.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=[
+                "--chunked-prefill-size",
+                f"{chunked_prefill_size}",
+            ],
+        )
+        try:
+            return self.generate_for_video(batch, num_frame)
+        finally:
+            kill_process_tree(process.pid)
+
+    def test_chunked_prefill(self):
+        output_chunked = self.run_generate(
+            chunked_prefill_size=1024, batch=False, num_frame=1
+        )
+        output_no_chunked = self.run_generate(
+            chunked_prefill_size=-1, batch=False, num_frame=1
+        )
+
+        print("output with chunked prefill:")
+        print(output_chunked)
+        print("output without chunked prefill:")
+        print(output_no_chunked)
+        assert output_chunked == output_no_chunked
+
+        output_chunked = self.run_generate(
+            chunked_prefill_size=1024, batch=True, num_frame=[2, 6, 8, 10]
+        )
+        output_no_chunked = self.run_generate(
+            chunked_prefill_size=-1, batch=True, num_frame=[2, 6, 8, 10]
+        )
+
+        print("output with chunked prefill:")
+        print(output_chunked)
+        print("output without chunked prefill:")
+        print(output_no_chunked)
+        assert output_chunked == output_no_chunked
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/test_vision_llm.py b/test/srt/test_vision_llm.py
new file mode 100644
index 00000000000..7cda64fc0c7
--- /dev/null
+++ b/test/srt/test_vision_llm.py
@@ -0,0 +1,212 @@
+"""
+Usage:
+python3 -m unittest test_vision_llm.TestMiniCPMVLogits.test_encode_output
+"""
+
+import unittest
+from io import BytesIO
+
+import numpy as np
+import requests
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from transformers import AutoModel, AutoProcessor, AutoTokenizer
+
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.conversation import generate_chat_conv
+from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.openai_api.protocol import ChatCompletionRequest
+from sglang.srt.server_args import ServerArgs
+
+MiniCPMV = "openbmb/MiniCPM-V-2_6"
+
+
+# Test the logits output between HF and SGLang
+class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true"
+        cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        cls.model_path = ""
+        cls.chat_template = ""
+        cls.processor = ""
+        response = requests.get(cls.image_url)
+        cls.main_image = Image.open(BytesIO(response.content))
+
+    def compare_outputs(self, sglang_output: torch.Tensor, hf_output: torch.Tensor):
+        # Convert to float32 for numerical stability if needed
+        hf = hf_output.float()
+        sg = sglang_output.float()
+
+        # Basic shape and dtype comparison
+        print("\n=== Basic Properties ===")
+        print(f"Shapes match: {hf.shape == sg.shape}")
+        print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}")
+        print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}")
+
+        # Move tensors to CPU for numpy operations
+        hf_np = hf.cpu().numpy()
+        sg_np = sg.cpu().numpy()
+
+        # Statistical metrics
+        print("\n=== Statistical Metrics ===")
+        print(f"Mean absolute difference: {torch.mean(torch.abs(hf - sg)).item():.6f}")
+        print(f"Max absolute difference: {torch.max(torch.abs(hf - sg)).item():.6f}")
+        print(f"Mean squared error: {torch.mean((hf - sg) ** 2).item():.6f}")
+        print(
+            f"Root mean squared error: {torch.sqrt(torch.mean((hf - sg) ** 2)).item():.6f}"
+        )
+
+        # Cosine similarity (across feature dimension)
+        cos_sim = F.cosine_similarity(hf, sg)
+        print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}")
+        print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}")
+
+        # Find largest absolute differences
+        print("\n=== Largest Absolute Differences ===")
+        diffs = torch.abs(hf - sg)
+        flat_diffs = diffs.flatten()
+
+        # Get indices of top 10 differences
+        top_k = 10
+        top_values, top_flat_indices = torch.topk(flat_diffs, top_k)
+
+        # Convert flat indices to multidimensional indices
+        top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), diffs.shape)
+
+        print(f"\nTop {top_k} largest absolute differences:")
+        print(
+            "Index".ljust(30)
+            + "Difference".ljust(15)
+            + "HF Value".ljust(15)
+            + "SGLang Value"
+        )
+        print("-" * 75)
+
+        for i in range(top_k):
+            # Get the index tuple for this difference
+            idx = tuple(dim[i] for dim in top_indices)
+            diff_val = top_values[i].item()
+            hf_val = hf[idx].item()
+            sg_val = sg[idx].item()
+
+            # Format the index tuple and values
+            idx_str = str(idx)
+            print(f"{idx_str:<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}")
+
+        np.testing.assert_allclose(hf_np, sg_np)
+
+    def get_processor_output(self):
+        json_str = f"""
+        {{
+            "model": "{self.model_path}",
+            "messages": [
+                {{
+                    "role": "user",
+                    "content": [
+                        {{
+                            "type": "image_url",
+                            "image_url": {{
+                                "url": "{self.image_url}"
+                            }}
+                        }},
+                        {{
+                            "type": "text",
+                            "text": "What's in this picture?"
+ }} + ] + }} + ] +}} + """ + + req = ChatCompletionRequest.model_validate_json(json_str) + + conv = generate_chat_conv(req, template_name=self.chat_template) + + text = conv.get_prompt() + + # Process inputs using processor + # FIXME: the formal arguments may differ + inputs = self.processor( + text=[text], + images=[self.main_image], + return_tensors="pt", + ).to(self.device) + + return inputs + + def get_sglang_model(self): + model_runner = ModelRunner( + model_config=ModelConfig(self.model_path, model_override_args="{}"), + mem_fraction_static=0.8, + gpu_id=0, + tp_rank=0, + tp_size=1, + nccl_port=12435, + server_args=ServerArgs( + model_path=self.model_path, + disable_cuda_graph=True, + ), + ) + return model_runner.model + + +class TestMiniCPMVLogits(VisionLLMLogitsBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model_path = MiniCPMV + cls.tokenizer = AutoTokenizer.from_pretrained( + cls.model_path, trust_remote_code=True + ) + cls.processor = AutoProcessor.from_pretrained( + cls.model_path, trust_remote_code=True + ) + cls.chat_template = "minicpmv" + + cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + cls.model = AutoModel.from_pretrained( + cls.model_path, torch_dtype=torch.bfloat16, trust_remote_code=True + ).eval() + cls.model.to(cls.device) + + async def test_encode_output(self): + inputs = self.get_processor_output() + + with torch.no_grad(): + model_inputs = { + "input_ids": inputs.input_ids, + "image_bound": inputs.image_bound, + "pixel_values": inputs.pixel_values, + "tgt_sizes": inputs.tgt_sizes, + } + (hf_output, _) = self.model.get_vllm_embedding( + model_inputs, + ) + hf_output = hf_output.squeeze(0) + + with torch.no_grad(): + model = self.get_sglang_model() + input_ids = inputs["input_ids"].to(self.device).flatten() + image_inputs = model._parse_and_validate_inputs( + input_ids=input_ids, + **{ + "pixel_values": [inputs["pixel_values"]], + "tgt_sizes": [inputs["tgt_sizes"]], + "im_start_id": [self.tokenizer.im_start_id], + "im_end_id": [self.tokenizer.im_end_id], + "slice_start_id": [self.tokenizer.slice_start_id], + "slice_end_id": [self.tokenizer.slice_end_id], + }, + ) + (sglang_output, _) = model.get_embedding( + input_ids=input_ids, image_inputs=image_inputs + ) + + self.compare_outputs(sglang_output, hf_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index e19e6b01d51..01762202882 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -171,7 +171,7 @@ def test_multi_images_chat_completion(self): text = response.choices[0].message.content assert isinstance(text, str) print(text) - assert "man" in text or "cab" in text, text + assert "man" in text or "cab" in text or "SUV" in text or "taxi" in text, text assert "logo" in text or '"S"' in text or "SG" in text, text assert response.id assert response.created @@ -180,7 +180,9 @@ def test_multi_images_chat_completion(self): assert response.usage.total_tokens > 0 def prepare_video_messages(self, video_path): - max_frames_num = 32 + # the memory consumed by the Vision Attention varies a lot, e.g. 
blocked qkv vs full-sequence sdpa + # the size of the video embeds differs from the `modality` argument when preprocessed + max_frames_num = 12 vr = VideoReader(video_path, ctx=cpu(0)) total_frame_num = len(vr) uniform_sampled_frames = np.linspace( @@ -392,34 +394,33 @@ def tearDownClass(cls): def test_chat_completion(self): client = openai.Client(api_key=self.api_key, base_url=self.base_url) - response = client.chat.completions.create( - model="default", - messages=[ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" + with self.assertRaises(openai.BadRequestError) as cm: + client.chat.completions.create( + model="default", + messages=[ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" + }, }, - }, - { - "type": "text", - "text": "Give a lengthy description of this picture", - }, - ], - }, - ], - temperature=0, - ) + { + "type": "text", + "text": "Give a lengthy description of this picture", + }, + ], + }, + ], + temperature=0, + ) - assert response.choices[0].finish_reason == "abort" - assert response.id - assert response.created - assert response.usage.prompt_tokens > 0 - assert response.usage.completion_tokens > 0 - assert response.usage.total_tokens > 0 + self.assertIn( + "Multimodal prompt is too long after expanding multimodal tokens.", + str(cm.exception), + ) class TestMllamaServer(TestOpenAIVisionServer): @@ -444,5 +445,24 @@ def test_video_chat_completion(self): pass +class TestMinicpmvServer(TestOpenAIVisionServer): + @classmethod + def setUpClass(cls): + cls.model = "openbmb/MiniCPM-V-2_6" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--chat-template", + "minicpmv", + ], + ) + cls.base_url += "/v1" + + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_w8a8_quantization.py b/test/srt/test_w8a8_quantization.py new file mode 100644 index 00000000000..78579d5e2de --- /dev/null +++ b/test/srt/test_w8a8_quantization.py @@ -0,0 +1,74 @@ +import time +import unittest +from types import SimpleNamespace + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestW8A8(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "neuralmagic/Meta-Llama-3-8B-Instruct-quantized.w8a8" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--quantization", "w8a8_int8"], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.7) + + def run_decode(self, max_new_tokens): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 
0, + "max_new_tokens": max_new_tokens, + }, + "ignore_eos": True, + }, + ) + return response.json() + + def test_throughput(self): + max_tokens = 256 + + tic = time.time() + res = self.run_decode(max_tokens) + tok = time.time() + print(res["text"]) + throughput = max_tokens / (tok - tic) + print(f"Throughput: {throughput} tokens/s") + assert throughput >= 140 + + +if __name__ == "__main__": + unittest.main()