From c8423ca3112f6bf638f294a548e16ab4a3e79f1f Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Wed, 14 Aug 2024 15:27:35 +0800 Subject: [PATCH 1/7] ci: update timeout and retry (#1086) Co-authored-by: Liangsheng Yin --- .github/workflows/accuracy-test.yml | 4 +++- .github/workflows/e2e-test.yml | 5 +++++ .github/workflows/moe-test.yml | 15 +++++++++++---- .github/workflows/unit-test.yml | 4 ++++ test/srt/test_moe_serving_throughput.py | 2 +- 5 files changed, 24 insertions(+), 6 deletions(-) diff --git a/.github/workflows/accuracy-test.yml b/.github/workflows/accuracy-test.yml index 16bb584f4aa..da2d98e861e 100644 --- a/.github/workflows/accuracy-test.yml +++ b/.github/workflows/accuracy-test.yml @@ -6,11 +6,13 @@ on: paths: - "python/sglang/**" - "test/**" + - ".github/workflows/accuracy-test.yml" pull_request: branches: [ main ] paths: - "python/sglang/**" - "test/**" + - ".github/workflows/accuracy-test.yml" workflow_dispatch: concurrency: @@ -43,4 +45,4 @@ jobs: run: | cd test/srt python3 test_eval_accuracy_large.py - timeout-minutes: 20 + timeout-minutes: 10 diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml index 455594bd723..3a338a6577a 100644 --- a/.github/workflows/e2e-test.yml +++ b/.github/workflows/e2e-test.yml @@ -6,11 +6,13 @@ on: paths: - "python/sglang/**" - "test/**" + - ".github/workflows/e2e-test.yml" pull_request: branches: [ main ] paths: - "python/sglang/**" - "test/**" + - ".github/workflows/e2e-test.yml" workflow_dispatch: concurrency: @@ -39,13 +41,16 @@ jobs: run: | cd test/srt python3 -m unittest test_serving_throughput.TestServingThroughput.test_default + timeout-minutes: 10 - name: Benchmark Serving Throughput (w/o RadixAttention) run: | cd test/srt python3 -m unittest test_serving_throughput.TestServingThroughput.test_default_without_radix_cache + timeout-minutes: 10 - name: Benchmark Serving Throughput (w/ ChunkedPrefill) run: | cd test/srt python3 -m unittest test_serving_throughput.TestServingThroughput.test_default_with_chunked_prefill + timeout-minutes: 10 diff --git a/.github/workflows/moe-test.yml b/.github/workflows/moe-test.yml index a781f2eff80..39eb2a71dd9 100644 --- a/.github/workflows/moe-test.yml +++ b/.github/workflows/moe-test.yml @@ -6,11 +6,13 @@ on: paths: - "python/sglang/**" - "test/**" + - ".github/workflows/moe-test.yml" pull_request: branches: [ main ] paths: - "python/sglang/**" - "test/**" + - ".github/workflows/moe-test.yml" workflow_dispatch: concurrency: @@ -36,7 +38,12 @@ jobs: pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/ --force-reinstall - name: Benchmark MOE Serving Throughput - run: | - cd test/srt - python3 -m unittest test_moe_serving_throughput.TestServingThroughput.test_default - python3 -m unittest test_moe_serving_throughput.TestServingThroughput.test_default_without_radix_cache + uses: nick-fields/retry@v3 + with: + timeout_minutes: 15 + max_attempts: 2 + retry_on: error + command: | + cd test/srt + python3 -m unittest test_moe_serving_throughput.TestServingThroughput.test_default + python3 -m unittest test_moe_serving_throughput.TestServingThroughput.test_default_without_radix_cache diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index f9b79dc6745..59228585fea 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -6,11 +6,13 @@ on: paths: - "python/sglang/**" - "test/**" + - ".github/workflows/unit-test.yml" pull_request: branches: [ main ] paths: - "python/sglang/**" - "test/**" + - 
".github/workflows/unit-test.yml" workflow_dispatch: concurrency: @@ -41,8 +43,10 @@ jobs: run: | cd test/srt python3 run_suite.py --suite minimal + timeout-minutes: 15 - name: Test Frontend Language run: | cd test/lang python3 run_suite.py --suite minimal + timeout-minutes: 10 diff --git a/test/srt/test_moe_serving_throughput.py b/test/srt/test_moe_serving_throughput.py index 48798c5d5f0..713eba7abb8 100644 --- a/test/srt/test_moe_serving_throughput.py +++ b/test/srt/test_moe_serving_throughput.py @@ -73,7 +73,7 @@ def test_default(self): if os.getenv("SGLANG_IS_IN_CI", "false") == "true": # A100 (PCIE) performance - assert res["output_throughput"] > 950 + assert res["output_throughput"] > 930 def test_default_without_radix_cache(self): res = self.run_test( From 616b59f384ad13b824fa8bb634444b43967f8c8a Mon Sep 17 00:00:00 2001 From: rainred <107027757+gryffindor-rr@users.noreply.github.com> Date: Wed, 14 Aug 2024 15:28:04 +0800 Subject: [PATCH 2/7] [Feature] modify Runtime to support skip_tokenizer_init (#1088) Co-authored-by: lzhang --- python/sglang/srt/server.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 7331425fae9..8f735ac0c74 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -533,11 +533,18 @@ async def async_generate( prompt: str, sampling_params: Optional[Dict] = None, ): - json_data = { - "text": prompt, - "sampling_params": sampling_params, - "stream": True, - } + if self.server_args.skip_tokenizer_init: + json_data = { + "input_ids": prompt, + "sampling_params": sampling_params, + "stream": True, + } + else: + json_data = { + "text": prompt, + "sampling_params": sampling_params, + "stream": True, + } pos = 0 timeout = aiohttp.ClientTimeout(total=3 * 3600) @@ -549,10 +556,13 @@ async def async_generate( if chunk == "data: [DONE]\n\n": break data = json.loads(chunk[5:].strip("\n")) - cur = data["text"][pos:] - if cur: - yield cur - pos += len(cur) + if hasattr(data, "text"): + cur = data["text"][pos:] + if cur: + yield cur + pos += len(cur) + else: + yield data add_request = async_generate From 8f790ac1005cfb5403a0a1e847bb0e050a4282da Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 14 Aug 2024 03:25:38 -0700 Subject: [PATCH 3/7] Fix a bug in cuda graph runner (#1094) --- python/sglang/srt/model_executor/cuda_graph_runner.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 9bfd4a646c2..a74e8eef787 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -98,8 +98,8 @@ def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile): self.req_pool_indices = torch.zeros( (self.max_bs,), dtype=torch.int32, device="cuda" ) - self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda") - self.position_ids_offsets = torch.zeros( + self.seq_lens = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda") + self.position_ids_offsets = torch.ones( (self.max_bs,), dtype=torch.int32, device="cuda" ) self.out_cache_loc = torch.zeros( @@ -201,7 +201,7 @@ def run_once(): out_cache_loc=out_cache_loc, return_logprob=False, top_logprobs_nums=0, - positions=(seq_lens - 1).to(torch.int64), + positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64), 
flashinfer_decode_wrapper=flashinfer_decode_wrapper, ) @@ -225,8 +225,8 @@ def replay(self, batch: ScheduleBatch): index = bisect.bisect_left(self.batch_size_list, raw_bs) bs = self.batch_size_list[index] if bs != raw_bs: - self.seq_lens.fill_(1) - self.position_ids_offsets.zero_() + self.seq_lens.zero_() + self.position_ids_offsets.fill_(1) self.out_cache_loc.zero_() # Common inputs From f14569f64aa19bcdbf51e08d0aba7e19ccfb5b88 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Wed, 14 Aug 2024 18:36:24 +0800 Subject: [PATCH 4/7] ci: remove workflow path trigger (#1096) --- .github/workflows/accuracy-test.yml | 2 -- .github/workflows/e2e-test.yml | 2 -- .github/workflows/moe-test.yml | 2 -- .github/workflows/unit-test.yml | 2 -- 4 files changed, 8 deletions(-) diff --git a/.github/workflows/accuracy-test.yml b/.github/workflows/accuracy-test.yml index da2d98e861e..374f0d2856d 100644 --- a/.github/workflows/accuracy-test.yml +++ b/.github/workflows/accuracy-test.yml @@ -6,13 +6,11 @@ on: paths: - "python/sglang/**" - "test/**" - - ".github/workflows/accuracy-test.yml" pull_request: branches: [ main ] paths: - "python/sglang/**" - "test/**" - - ".github/workflows/accuracy-test.yml" workflow_dispatch: concurrency: diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/e2e-test.yml index 3a338a6577a..cb11e0db535 100644 --- a/.github/workflows/e2e-test.yml +++ b/.github/workflows/e2e-test.yml @@ -6,13 +6,11 @@ on: paths: - "python/sglang/**" - "test/**" - - ".github/workflows/e2e-test.yml" pull_request: branches: [ main ] paths: - "python/sglang/**" - "test/**" - - ".github/workflows/e2e-test.yml" workflow_dispatch: concurrency: diff --git a/.github/workflows/moe-test.yml b/.github/workflows/moe-test.yml index 39eb2a71dd9..51f7d022614 100644 --- a/.github/workflows/moe-test.yml +++ b/.github/workflows/moe-test.yml @@ -6,13 +6,11 @@ on: paths: - "python/sglang/**" - "test/**" - - ".github/workflows/moe-test.yml" pull_request: branches: [ main ] paths: - "python/sglang/**" - "test/**" - - ".github/workflows/moe-test.yml" workflow_dispatch: concurrency: diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 59228585fea..4b61c4c4ed3 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -6,13 +6,11 @@ on: paths: - "python/sglang/**" - "test/**" - - ".github/workflows/unit-test.yml" pull_request: branches: [ main ] paths: - "python/sglang/**" - "test/**" - - ".github/workflows/unit-test.yml" workflow_dispatch: concurrency: From fe5024325b8bf952714a49575c86e9b608d01f58 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Wed, 14 Aug 2024 19:40:05 +0800 Subject: [PATCH 5/7] docs: update README (#1098) --- .github/ISSUE_TEMPLATE/1-bug-report.yml | 3 ++- .github/ISSUE_TEMPLATE/2-feature-request.yml | 6 ++++++ .github/pull_request_template.md | 7 ++++--- README.md | 2 +- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/1-bug-report.yml b/.github/ISSUE_TEMPLATE/1-bug-report.yml index c1684c14bb4..5f6734867ca 100644 --- a/.github/ISSUE_TEMPLATE/1-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/1-bug-report.yml @@ -12,6 +12,7 @@ body: - label: 2. The bug has not been fixed in the latest version. - label: 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback. - label: 4. 
If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed. + - label: 5. Please use English, otherwise it will be closed. - type: textarea attributes: label: Describe the bug @@ -31,7 +32,7 @@ body: attributes: label: Environment description: | - Please provide necessary environment information here with `python3 -m sglang.check_env`. + Please provide necessary environment information here with `python3 -m sglang.check_env`. Otherwise the issue will be closed. placeholder: Environment here. validations: required: true diff --git a/.github/ISSUE_TEMPLATE/2-feature-request.yml b/.github/ISSUE_TEMPLATE/2-feature-request.yml index 5ab369f8b09..31bc4a127e6 100644 --- a/.github/ISSUE_TEMPLATE/2-feature-request.yml +++ b/.github/ISSUE_TEMPLATE/2-feature-request.yml @@ -3,6 +3,12 @@ description: Suggest an idea for this project title: "[Feature] " body: +- type: checkboxes + attributes: + label: Checklist + options: + - label: 1. If the issue you raised is not a feature but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed. + - label: 2. Please use English, otherwise it will be closed. - type: textarea attributes: label: Motivation diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 20f4a10bc56..acc9682d64c 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -10,6 +10,7 @@ Briefly describe the changes made in this PR. ## Checklist -1. Ensure pre-commit `pre-commit run --all-files` or other linting tools are used to fix potential lint issues. -2. Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness. -3. Modify documentation as needed, such as docstrings or example tutorials. +-[] Before submitting a PR for review, make sure it has passed verification in your local development environment **at least**. +-[] Ensure pre-commit `pre-commit run --all-files` or other linting tools are used to fix potential lint issues. +-[] Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness. +-[] Modify documentation as needed, such as docstrings or example tutorials. diff --git a/README.md b/README.md index 117c329bb03..451e0a69348 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ docker run --gpus all \ 2. Execute the command `docker compose up -d` in your terminal. ### Common Notes -- If you cannot install FlashInfer, check out its [installation](https://docs.flashinfer.ai/installation.html#) page. If you still cannot install it, you can use the slower Triton kernels by adding `--disable-flashinfer` when launching the server. +- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. If you are using NVIDIA GPU devices below sm80, such as T4, you can't use SGLang for the time being. We expect to resolve this issue soon, so please stay tuned. If you encounter any FlashInfer-related issues on sm80+ devices (e.g., A100, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise a issue. - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`. 
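For reference, a minimal sketch of the Triton fallback mentioned in the note above (the model path and port are illustrative; the two flags are the ones named in the note):

```bash
# Launch the SRT server with FlashInfer disabled, falling back to Triton kernels.
python3 -m sglang.launch_server \
  --model-path meta-llama/Meta-Llama-3-8B-Instruct \
  --port 30000 \
  --disable-flashinfer --disable-flashinfer-sampling
```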
## Backend: SGLang Runtime (SRT) From a59636bb5e68f36308bb092674429d27c05cf125 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Wed, 14 Aug 2024 04:40:44 -0700 Subject: [PATCH 6/7] Update grok 1 model (#1095) --- benchmark/gsm8k/bench_sglang.py | 3 + python/sglang/bench_latency.py | 1 + python/sglang/srt/layers/activation.py | 1 - .../sglang/srt/layers/fused_moe/__init__.py | 1 + .../srt/layers/{ => fused_moe}/fused_moe.py | 273 ++++---- python/sglang/srt/layers/fused_moe/layer.py | 587 ++++++++++++++++++ python/sglang/srt/layers/logits_processor.py | 8 +- .../sglang/srt/model_executor/model_runner.py | 4 +- python/sglang/srt/models/grok.py | 444 ++----------- python/sglang/srt/models/mixtral.py | 1 - python/sglang/srt/utils.py | 3 +- 11 files changed, 813 insertions(+), 513 deletions(-) create mode 100644 python/sglang/srt/layers/fused_moe/__init__.py rename python/sglang/srt/layers/{ => fused_moe}/fused_moe.py (78%) create mode 100644 python/sglang/srt/layers/fused_moe/layer.py diff --git a/benchmark/gsm8k/bench_sglang.py b/benchmark/gsm8k/bench_sglang.py index 298ec11d73d..652086f913b 100644 --- a/benchmark/gsm8k/bench_sglang.py +++ b/benchmark/gsm8k/bench_sglang.py @@ -88,6 +88,9 @@ def few_shot_gsm8k(s, question): for i in range(len(states)): preds.append(get_answer_value(states[i]["answer"])) + # print(f"{preds=}") + # print(f"{labels=}") + # Compute accuracy acc = np.mean(np.array(preds) == np.array(labels)) invalid = np.mean(np.array(preds) == INVALID) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index ee227849cf8..e500d30d1c5 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -221,6 +221,7 @@ def correctness_test( # Prepare inputs input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer) + rank_print(f"{input_ids=}") if bench_args.cut_len > 0: # Prefill diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index 64d3915946d..7cd8abb6f96 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -14,7 +14,6 @@ """Fused operators for activation layers.""" import torch -import torch.nn as nn import torch.nn.functional as F from flashinfer.activation import silu_and_mul from vllm.model_executor.custom_op import CustomOp diff --git a/python/sglang/srt/layers/fused_moe/__init__.py b/python/sglang/srt/layers/fused_moe/__init__.py new file mode 100644 index 00000000000..5f7691c09fd --- /dev/null +++ b/python/sglang/srt/layers/fused_moe/__init__.py @@ -0,0 +1 @@ +from sglang.srt.layers.fused_moe.layer import FusedMoE, FusedMoEMethodBase diff --git a/python/sglang/srt/layers/fused_moe.py b/python/sglang/srt/layers/fused_moe/fused_moe.py similarity index 78% rename from python/sglang/srt/layers/fused_moe.py rename to python/sglang/srt/layers/fused_moe/fused_moe.py index c5630fa5db4..717be5ce966 100644 --- a/python/sglang/srt/layers/fused_moe.py +++ b/python/sglang/srt/layers/fused_moe/fused_moe.py @@ -1,20 +1,5 @@ -""" -Copyright 2023-2024 SGLang Team -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -""" - # Adapted from -# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/fused_moe/fused_moe.py#L1 +# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe """Fused MoE kernel.""" import functools import json @@ -24,6 +9,7 @@ import torch import triton import triton.language as tl +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -373,6 +359,31 @@ def get_default_config( return config +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, Any]] = None, +): + if override_config: + config = override_config + else: + # First try to load optimal config from the file + E, _, N = w2_shape + configs = get_moe_configs(E, N, dtype) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype) + return config + + def fused_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -403,6 +414,41 @@ def fused_topk( return topk_weights, topk_ids +# This is used by the Deepseek-V2 model +def grouped_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0, +): + + assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = ( + scores.view(num_token, num_expert_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group) + .reshape(num_token, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + return topk_weights, topk_ids + + def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -425,24 +471,23 @@ def fused_experts( assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16] - M, _ = hidden_states.shape + num_tokens, _ = hidden_states.shape E, N, _ = w1.shape + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE + M = min(num_tokens, CHUNK_SIZE) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + ) - if override_config: - config = override_config - else: - # First try to load optimal config from the file - configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None) - - if configs: - # If an optimal configuration map has been found, look up 
the - # optimal config - config = configs[min(configs.keys(), key=lambda x: abs(x - M))] - else: - # Else use the default config - config = get_default_config( - M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None - ) + config = get_config_func(M) intermediate_cache1 = torch.empty( (M, topk_ids.shape[1], N), @@ -460,56 +505,85 @@ def fused_experts( dtype=hidden_states.dtype, ) - sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( - topk_ids, config["BLOCK_SIZE_M"], E - ) compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16 - invoke_fused_moe_kernel( - hidden_states, - w1, - intermediate_cache1, - a1_scale, - w1_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, - topk_ids.shape[1], - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) - ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = ( + chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, num_tokens), + ) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.shape + + if tokens_in_chunk == 0: + break + + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + # Adjust the intermediate cache size and config for the last + # chunk. Note that in most cases we only have one chunk + # so the cache size and config are already set correctly and + # do not need to be adjusted. + intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + config = get_config_func(tokens_in_chunk) + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + curr_topk_ids, config["BLOCK_SIZE_M"], E + ) - invoke_fused_moe_kernel( - intermediate_cache2, - w2, - intermediate_cache3, - a2_scale, - w2_scale, - topk_weights, - topk_ids, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - True, - 1, - config, - compute_type=compute_type, - use_fp8=use_fp8, - ) + invoke_fused_moe_kernel( + curr_hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) - if inplace: - return torch.sum( + ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel( + intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8=use_fp8, + ) + + torch.sum( intermediate_cache3.view(*intermediate_cache3.shape), dim=1, - out=hidden_states, + out=out_hidden_states[begin_chunk_idx:end_chunk_idx], ) - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) + return out_hidden_states def fused_moe( @@ -521,6 +595,9 @@ def fused_moe( renormalize: bool, inplace: bool = False, override_config: Optional[Dict[str, Any]] = None, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, 
+ topk_group: Optional[int] = None, use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, @@ -543,6 +620,10 @@ def fused_moe( Defaults to False. - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. + - num_expert_group: Optional[int]: additional parameter for grouped_topk + - topk_group: Optional[int]: additional parameter for grouped_topk + - use_grouped_topk: If True, use grouped_topk instead of fused_topk + note: Deepseekv2 model uses grouped_topk - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for @@ -556,12 +637,18 @@ def fused_moe( # Check constraints. assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - if hasattr(ops, "topk_softmax"): - topk_weights, topk_ids = fused_topk( - hidden_states, gating_output, topk, renormalize + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + topk_weights, topk_ids = grouped_topk( + hidden_states, + gating_output, + topk, + renormalize, + num_expert_group, + topk_group, ) else: - topk_weights, topk_ids = fused_topk_v0_4_3( + topk_weights, topk_ids = fused_topk( hidden_states, gating_output, topk, renormalize ) @@ -579,33 +666,3 @@ def fused_moe( a1_scale=a1_scale, a2_scale=a2_scale, ) - - -def fused_topk_v0_4_3( - hidden_states: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, -): - import vllm._moe_C as moe_kernels - - M, _ = hidden_states.shape - - topk_weights = torch.empty( - M, topk, dtype=torch.float32, device=hidden_states.device - ) - topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) - token_expert_indicies = torch.empty( - M, topk, dtype=torch.int32, device=hidden_states.device - ) - moe_kernels.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), # TODO(woosuk): Optimize this. - ) - del token_expert_indicies # Not used. Will be used in the future. 
- if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - - return topk_weights, topk_ids diff --git a/python/sglang/srt/layers/fused_moe/layer.py b/python/sglang/srt/layers/fused_moe/layer.py new file mode 100644 index 00000000000..0b17c14ffd8 --- /dev/null +++ b/python/sglang/srt/layers/fused_moe/layer.py @@ -0,0 +1,587 @@ +# Adapted from +# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe +from abc import abstractmethod +from typing import List, Optional, Tuple + +import torch +from vllm.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from vllm.logger import init_logger +from vllm.model_executor.custom_op import CustomOp +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.utils import set_weight_attrs + +logger = init_logger(__name__) + + +class FusedMoEMethodBase(QuantizeMethodBase): + + @abstractmethod + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + raise NotImplementedError + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + raise NotImplementedError + + +class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): + """MoE method without quantization.""" + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + # Fused gate_up_proj (column parallel) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + # down_proj (row parallel) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + return self.forward( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize, + use_grouped_topk, + num_expert_group, + topk_group, + ) + + def forward_cuda( + self, + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + num_expert_group: Optional[int], + topk_group: Optional[int], + ) -> torch.Tensor: + from sglang.srt.layers.fused_moe.fused_moe import fused_moe + + return fused_moe( + x, + w1, + w2, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) + + def forward_cpu(self, *args, **kwargs): + 
raise NotImplementedError("The CPU backend currently does not support MoE.") + + def forward_tpu( + self, + x: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + num_expert_group: Optional[int], + topk_group: Optional[int], + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe + + assert not use_grouped_topk + assert num_expert_group is None + assert topk_group is None + return fused_moe(x, w1, w2, router_logits, top_k, renormalize) + + +class FusedMoE(torch.nn.Module): + """FusedMoE layer for MoE models. + + This layer contains both MergedColumnParallel weights (gate_up_proj / + w13) and RowParallelLinear weights (down_proj/ w2). + + Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We + copy that naming convention here and handle any remapping in the + load_weights function in each model implementation. + + Args: + num_experts: Number of experts in the model + top_k: Number of experts selected for each token + hidden_size: Input hidden state size of the transformer + intermediate_size: Intermediate size of the experts + params_dtype: Data type for the parameters. + reduce_results: Whether to all all_reduce on the output of the layer + renomalize: Whether to renormalize the logits in the fused_moe kernel + quant_config: Quantization configure. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + ): + super().__init__() + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = ( + tp_size if tp_size is not None else get_tensor_model_parallel_world_size() + ) + self.top_k = top_k + self.num_experts = num_experts + self.intermediate_size_per_partition = intermediate_size // self.tp_size + self.reduce_results = reduce_results + self.renormalize = renormalize + self.use_grouped_topk = use_grouped_topk + if self.use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.num_expert_group = num_expert_group + self.topk_group = topk_group + + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = ( + UnquantizedFusedMoEMethod() + ) + else: + if isinstance(quant_config, Fp8Config): + self.quant_method = Fp8MoEMethod(quant_config) + else: + self.quant_method = quant_config.get_quant_method(self, prefix) + assert self.quant_method is not None + + self.quant_method.create_weights( + layer=self, + num_experts=num_experts, + hidden_size=hidden_size, + intermediate_size=self.intermediate_size_per_partition, + params_dtype=params_dtype, + weight_loader=self.weight_loader, + ) + + def weight_loader( + self, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + weight_name: str, + shard_id: int, + expert_id: int, + pre_sharded: bool, + ): + param_data = param.data + + # Input scales can be loaded directly and should be equal. + if "input_scale" in weight_name: + if ( + param_data[expert_id] != 1 + and (param_data[expert_id] - loaded_weight).abs() > 1e-5 + ): + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param_data[expert_id]} " + f"vs. 
{loaded_weight}" + ) + param_data[expert_id] = loaded_weight + # Weight scales + elif "weight_scale" in weight_name: + # If we are in merged column case (gate_up_proj) + # shard_id 0 == gate_proj / w1 + # shard_id 2 == up_proj / w3 + if shard_id == 0 or shard_id == 2: + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. + idx = 0 if shard_id == 0 else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + # shard_id 1 == down_proj / w2 + else: + param_data[expert_id] = loaded_weight + # Weights + else: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.intermediate_size_per_partition + if pre_sharded: + shard = slice(None) + else: + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + + # w1, gate_proj case: Load into first shard of w13. + if shard_id == 0: + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + # w3, up_proj case: Load into second shard of w13. + elif shard_id == 2: + param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ + shard, : + ] + # w2, down_proj case: Load into only shard of w2. + elif shard_id == 1: + param_data[expert_id, :, :] = loaded_weight[:, shard] + else: + raise ValueError(f"Shard id must be in [0,1,2] but got {shard_id}") + + def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + assert self.quant_method is not None + + # Matrix multiply. + final_hidden_states = self.quant_method.apply( + self, + x=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + num_expert_group=self.num_expert_group, + topk_group=self.topk_group, + ) + + if self.reduce_results and self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + + return final_hidden_states + + @classmethod + def make_expert_params_mapping( + cls, + ckpt_gate_proj_name: str, + ckpt_down_proj_name: str, + ckpt_up_proj_name: str, + num_experts: int, + ) -> List[Tuple[str, str, int, int]]: + + gate_up = [ckpt_gate_proj_name, ckpt_up_proj_name] + gate_down_up = [ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name] + + return ( + [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + ( + "experts.w13_scale" + if weight_name in gate_up + else "experts.w2_scale" + ), + f"experts.{expert_id}.{weight_name}.weight_scale", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + + [ + # These are the weights for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + ( + "experts.w13_weight" + if weight_name in gate_up + else "experts.w2_weight" + ), + f"experts.{expert_id}.{weight_name}.weight", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + + [ + # These are the weight scales for the experts + # (param_name, weight_name, expert_id, shard_id) + ( + ( + "experts.a13_scale" + if weight_name in gate_up + else "experts.a2_scale" + ), + f"experts.{expert_id}.{weight_name}.input_scale", + expert_id, + shard_id, + ) + for expert_id in range(num_experts) + for shard_id, weight_name in enumerate(gate_down_up) + ] + ) + + +import torch +from torch.nn import Module +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, 
+ per_tensor_dequantize, +) +from vllm.utils import print_warning_once + + +class Fp8MoEMethod(FusedMoEMethodBase): + """MoE method for FP8. + Supports loading FP8 checkpoints with static weight scale and + dynamic/static activation scale. + + Also supports loading quantized FP16/BF16 model checkpoints with dynamic + activation scaling. The weight scaling factor will be initialized after + the model weights are loaded. + + Args: + quant_config: The quantization config. + """ + + def __init__(self, quant_config: Fp8Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + if self.quant_config.is_checkpoint_fp8_serialized: + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_scale = torch.nn.Parameter( + torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_scale", w13_scale) + + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_scale", w2_scale) + + # If loading fp8 checkpoint, pass the weight loaders. + # If loading an fp16 checkpoint, do not (we will quantize in + # process_weights_after_loading() + if self.quant_config.is_checkpoint_fp8_serialized: + set_weight_attrs(w13_scale, extra_weight_attrs) + set_weight_attrs(w2_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.quant_config.activation_scheme == "static": + if not self.quant_config.is_checkpoint_fp8_serialized: + raise ValueError( + "Found static activation scheme for checkpoint that " + "was not serialized fp8." + ) + + a13_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("a13_scale", a13_scale) + set_weight_attrs(a13_scale, extra_weight_attrs) + + a2_scale = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("a2_scale", a2_scale) + set_weight_attrs(a2_scale, extra_weight_attrs) + else: + layer.a13_scale = None + layer.a2_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + + # If checkpoint is fp16, quantize in place. + if not self.quant_config.is_checkpoint_fp8_serialized: + w13_weight = torch.empty_like( + layer.w13_weight.data, dtype=torch.float8_e4m3fn + ) + w2_weight = torch.empty_like( + layer.w2_weight.data, dtype=torch.float8_e4m3fn + ) + + # Re-initialize w13_scale because we directly quantize + # merged w13 weights and generate a single scaling factor. 
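# (Editor's note) As the tuple unpacking in the loop below shows, ops.scaled_fp8_quant
# returns a pair of (quantized fp8 tensor, per-tensor scale), i.e. roughly
# weight_fp16 ~= scale * weight_fp8, so one dynamic scale is kept per expert
# alongside the quantized weights.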
+ layer.w13_scale = torch.nn.Parameter( + torch.ones( + layer.num_experts, dtype=torch.float32, device=w13_weight.device + ), + requires_grad=False, + ) + for expert in range(layer.num_experts): + w13_weight[expert, :, :], layer.w13_scale[expert] = ( + ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_scale[expert] = ops.scaled_fp8_quant( + layer.w2_weight.data[expert, :, :] + ) + layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) + return + + # If checkpoint is fp8, we need to handle that the + # MoE kernels require single activation scale and single weight + # scale for w13 per expert. + else: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.quant_config.activation_scheme == "static": + if layer.a13_scale is None or layer.a2_scale is None: + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None." + ) + if not all_close_1d(layer.a13_scale) or not all_close_1d( + layer.a2_scale + ): + print_warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. " + ) + layer.a13_scale = torch.nn.Parameter( + layer.a13_scale.max(), requires_grad=False + ) + layer.a2_scale = torch.nn.Parameter( + layer.a2_scale.max(), requires_grad=False + ) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start : start + shard_size, :], + layer.w13_scale[expert_id][shard_id], + ) + layer.w13_weight[expert_id][start : start + shard_size, :], _ = ( + ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) + ) + start += shard_size + + layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + ) -> torch.Tensor: + + from sglang.srt.layers.fused_moe.fused_moe import fused_moe + + return fused_moe( + x, + layer.w13_weight, + layer.w2_weight, + router_logits, + top_k, + renormalize=renormalize, + inplace=True, + use_fp8=True, + w1_scale=layer.w13_scale, + w2_scale=layer.w2_scale, + a1_scale=layer.a13_scale, + a2_scale=layer.a2_scale, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + ) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index cf5045fda5e..541fa0f1530 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -164,9 +164,9 @@ def forward( last_logits = last_logits[:, : self.config.vocab_size].float() if hasattr(self.config, "final_logit_softcapping"): - last_logits /= self.config.final_logit_softcapping + last_logits.div_(self.config.final_logit_softcapping) last_logits = torch.tanh(last_logits) - last_logits *= self.config.final_logit_softcapping + 
last_logits.mul_(self.config.final_logit_softcapping) # Return only last_logits if logprob is not requested if not logits_metadata.return_logprob: @@ -209,9 +209,9 @@ def forward( all_logits = all_logits[:, : self.config.vocab_size].float() if hasattr(self.config, "final_logit_softcapping"): - all_logits /= self.config.final_logit_softcapping + all_logits.div_(self.config.final_logit_softcapping) all_logits = torch.tanh(all_logits) - all_logits *= self.config.final_logit_softcapping + all_logits.mul_(self.config.final_logit_softcapping) all_logprobs = all_logits del all_logits, hidden_states diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 34a40c7d71a..9da284da65b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -53,7 +53,7 @@ from sglang.srt.utils import ( get_available_gpu_memory, is_generation_model, - is_llama3_405b_fp8, + is_llama3_405b_fp8_head_16, is_multimodal_model, monkey_patch_vllm_dummy_weight_loader, monkey_patch_vllm_p2p_access_check, @@ -158,7 +158,7 @@ def load_model(self): skip_tokenizer_init=True, ) - if is_llama3_405b_fp8(self.model_config) and self.tp_size <= 8: + if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8: # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints self.model_config.hf_config.num_key_value_heads = 8 vllm_model_config.hf_config.num_key_value_heads = 8 diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 13d4330d4c3..eff746f1dd1 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -16,20 +16,17 @@ # Adapted from # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1 """Inference-only Grok1 model.""" +import warnings from typing import Iterable, List, Optional, Tuple -import numpy as np import torch import torch.nn.functional as F -import tqdm from torch import nn from transformers import PretrainedConfig -from vllm import _custom_ops as ops from vllm.config import CacheConfig from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, ) from vllm.model_executor.layers.linear import ( QKVParallelLinear, @@ -37,7 +34,6 @@ RowParallelLinear, ) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, @@ -45,141 +41,13 @@ ) from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import print_warning_once -from sglang.srt.layers.fused_moe import fused_moe +from sglang.srt.layers.fused_moe import FusedMoE from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import InputMetadata -use_fused = True - - -class Grok1MLP(nn.Module): - def __init__( - self, - num_experts: int, - hidden_size: int, - intermediate_size: int, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - 
super().__init__() - self.num_experts = num_experts - self.ffn_dim = intermediate_size - self.hidden_dim = hidden_size - - self.w1 = ReplicatedLinear( - self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config - ) - self.w2 = ReplicatedLinear( - self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config - ) - self.w3 = ReplicatedLinear( - self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config - ) - - self.act_fn = nn.GELU() - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - w1_out, _ = self.w1(hidden_states) - w1_out = self.act_fn(w1_out) - w3_out, _ = self.w3(hidden_states) - current_hidden_states = w1_out * w3_out - current_hidden_states, _ = self.w2(current_hidden_states) - return current_hidden_states - - -class Grok1MoEUnfused(nn.Module): - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - self.config = config - self.rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_total_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - if self.tp_size > self.num_total_experts: - raise ValueError( - f"Tensor parallel size {self.tp_size} is greater than " - f"the number of experts {self.num_total_experts}." - ) - # Split experts equally between ranks - self.expert_indicies = np.array_split( - range(self.num_total_experts), self.tp_size - )[self.rank].tolist() - if not self.expert_indicies: - raise ValueError(f"Rank {self.rank} has no experts assigned to it.") - - self.experts = nn.ModuleList( - [ - ( - Grok1MLP( - self.num_total_experts, - config.hidden_size, - config.intermediate_size, - quant_config=quant_config, - ) - if idx in self.expert_indicies - else None - ) - for idx in range(self.num_total_experts) - ] - ) - self.gate = ReplicatedLinear( - config.hidden_size, self.num_total_experts, bias=False, quant_config=None - ) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - router_logits, _ = self.gate(hidden_states) - router_logits = 30 * F.tanh(router_logits / 30) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk( - routing_weights, self.top_k, dim=-1 - ) - routing_weights = routing_weights.to(hidden_states.dtype) - hidden_dim = hidden_states.shape[1] - - final_hidden_states = torch.zeros( - (hidden_states.shape[0], hidden_dim), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=self.num_total_experts - ).permute(2, 1, 0) - - for expert_idx in self.expert_indicies: - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx]) - - if top_x.shape[0] == 0: - continue - - # in torch it is faster to index using lists than torch tensors - top_x_list = top_x.tolist() - idx_list = idx.tolist() - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - current_hidden_states = ( - expert_layer(current_state) - * routing_weights[top_x_list, idx_list, None] - ) - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. 
- final_hidden_states.index_add_(0, top_x, current_hidden_states) - - return tensor_model_parallel_all_reduce(final_hidden_states) - class Grok1MoE(nn.Module): """A tensor-parallel MoE implementation for Grok1 that shards each expert @@ -197,221 +65,42 @@ def __init__( hidden_size: int, intermediate_size: int, params_dtype: Optional[torch.dtype] = None, - tp_size: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, ): super().__init__() - self.tp_size = tp_size or get_tensor_model_parallel_world_size() - self.num_total_experts = num_experts - self.top_k = top_k self.hidden_size = hidden_size - self.intermediate_size = intermediate_size // self.tp_size - self.quant_config = quant_config - - # FIXME(pcmoritz): Make this more general to support different - # quantization schemes - self.use_fp8 = isinstance(quant_config, Fp8Config) - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype # Gate always runs at half / full precision for now. self.gate = ReplicatedLinear( - self.hidden_size, - self.num_total_experts, + hidden_size, + num_experts, bias=False, - params_dtype=self.params_dtype, + params_dtype=params_dtype, quant_config=None, ) - if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized: - params_dtype = torch.float8_e4m3fn - - self.w13_weight = nn.Parameter( - torch.empty( - self.num_total_experts, - 2 * self.intermediate_size, - self.hidden_size, - dtype=params_dtype, - ) - ) - self.w2_weight = nn.Parameter( - torch.empty( - self.num_total_experts, - self.hidden_size, - self.intermediate_size, - dtype=params_dtype, - ) - ) - - set_weight_attrs( - self.w13_weight, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.w2_weight, - { - "weight_loader": self.weight_loader, - }, + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size, ) - # Used for fp8. - self.w13_scale = None - self.w2_scale = None - self.a13_scale = None - self.a2_scale = None - - if self.use_fp8: - # WEIGHT_SCALE (for fp8) - self.w13_scale = nn.Parameter( - torch.ones(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - self.w2_scale = nn.Parameter( - torch.ones(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - - # If loading fp8 checkpoint, pass the weight loaders. - # If loading an fp16 checkpoint, do not (we will quantize in - # process_weights_after_loading() - if quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs( - self.w13_scale, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.w2_scale, - { - "weight_loader": self.weight_loader, - }, - ) - - # ACT_SCALE (for fp8) - if quant_config.activation_scheme == "static": - if not quant_config.is_checkpoint_fp8_serialized: - raise ValueError( - "Found static activation scheme for checkpoint that " - "was not serialized fp8." 
- ) - self.a13_scale = nn.Parameter( - torch.zeros(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - self.a2_scale = nn.Parameter( - torch.zeros(self.num_total_experts, dtype=torch.float32), - requires_grad=False, - ) - - set_weight_attrs( - self.a13_scale, - { - "weight_loader": self.weight_loader, - }, - ) - set_weight_attrs( - self.a2_scale, - { - "weight_loader": self.weight_loader, - }, - ) - - def weight_loader( - self, - param: nn.Parameter, - loaded_weight: torch.Tensor, - weight_name: str, - expert_id: int, - pre_sharded: bool, - ): - param_data = param.data - shard_size = self.intermediate_size - if pre_sharded: - # The weight is already sharded. Readl the full shard - shard = slice(None) - else: - tp_rank = get_tensor_model_parallel_rank() - shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) - if weight_name.endswith("w1.weight"): - param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] - if weight_name.endswith("w3.weight"): - param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[ - shard, : - ] - if weight_name.endswith("w2.weight"): - param_data[expert_id, :, :] = loaded_weight[:, shard] - if "act_scale" in weight_name or "weight_scale" in weight_name: - param_data[expert_id] = loaded_weight - - def process_weights_after_loading(self): - # Fp8 is the only case where we need to process after loading. - if not self.use_fp8: - return - - # If checkpoint is fp16, quantize here. - if not self.quant_config.is_checkpoint_fp8_serialized: - w13_weight = torch.empty_like( - self.w13_weight.data, dtype=torch.float8_e4m3fn - ) - w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn) - for expert in range(self.num_total_experts): - w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant( - self.w13_weight.data[expert, :, :] - ) - w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant( - self.w2_weight.data[expert, :, :] - ) - self.w13_weight = nn.Parameter(w13_weight, requires_grad=False) - self.w2_weight = nn.Parameter(w2_weight, requires_grad=False) - - # If checkpoint is fp8 + static, cleanup act_scales. - # Since state_dict has an act_scale per expert but our kernels - # are passed one act_scale shared across all experts. - elif self.quant_config.activation_scheme == "static": - if self.a13_scale is None or self.a2_scale is None: - raise ValueError( - "QuantConfig has static quantization, but found " - "activation scales are None." - ) - - if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale): - print_warning_once( - "Found act_scales that are not equal for fp8 MoE layer. " - "Using the maximum across experts for each layer. " - ) - - self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False) - self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - num_tokens, hidden_size = hidden_states.shape + # NOTE: hidden_states can have either 1D or 2D shape. 
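# (Editor's note, continuing the NOTE above) The view(-1, hidden_size) below flattens
# any leading batch dimensions so the gate and the fused experts always see a 2D
# (num_tokens, hidden_size) tensor; orig_shape restores the caller's layout on return.
# The 30.0 * F.tanh(router_logits / 30.0) step smoothly caps the router logits to the
# range (-30, 30) before expert selection.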
+ orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = fused_moe( - hidden_states, - self.w13_weight, - self.w2_weight, - router_logits, - self.top_k, - renormalize=False, - inplace=True, - use_fp8=self.use_fp8, - w1_scale=self.w13_scale, - w2_scale=self.w2_scale, - a1_scale=self.a13_scale, - a2_scale=self.a2_scale, - ) - - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - - return final_hidden_states.view(num_tokens, hidden_size) + router_logits = 30.0 * F.tanh(router_logits / 30.0) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) class Grok1Attention(nn.Module): @@ -478,6 +167,7 @@ def __init__( layer_id=layer_id, logit_cap=logit_cap, ) + # TODO(lianmin): load logit cap from config def forward( self, @@ -502,7 +192,7 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) self.self_attn = Grok1Attention( hidden_size=self.hidden_size, @@ -513,18 +203,13 @@ def __init__( rope_theta=rope_theta, quant_config=quant_config, ) - if use_fused: - self.block_sparse_moe = Grok1MoE( - num_experts=config.num_local_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - quant_config=quant_config, - ) - else: - self.block_sparse_moe = Grok1MoEUnfused( - config=config, quant_config=quant_config - ) + self.block_sparse_moe = Grok1MoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + ) self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -536,6 +221,7 @@ def forward( hidden_states: torch.Tensor, input_metadata: InputMetadata, ) -> torch.Tensor: + # Self Attention hidden_states = ( self.post_attn_norm( self.self_attn( @@ -547,11 +233,11 @@ def forward( + hidden_states ) + # Fully Connected hidden_states = ( self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states))) + hidden_states ) - return hidden_states @@ -593,7 +279,6 @@ def forward( for i in range(len(self.layers)): hidden_states = self.layers[i](positions, hidden_states, input_metadata) - hidden_states = self.norm(hidden_states) hidden_states.mul_(self.config.output_multiplier_scale) return hidden_states @@ -615,8 +300,8 @@ def __init__( # Monkey patch _prepare_weights to load pre-sharded weights setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) + warnings.filterwarnings("ignore", category=FutureWarning) - @torch.no_grad() def forward( self, input_ids: torch.Tensor, @@ -637,50 +322,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "v_proj", "v"), ] - if use_fused: - expert_params_mapping = ( - [ - # These are the weight scales for the experts - # (param_name, weight_name, expert_id) - ( - "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale", - f"experts.{expert_id}.{weight_name}.weight_scale", - expert_id, - ) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] - + [ - # 
These are the weights for the experts - # (param_name, weight_name, expert_id) - ( - "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight", - f"experts.{expert_id}.{weight_name}.weight", - expert_id, - ) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] - + [ - # These are the activation scales for the experts - # (param_name, weight_name, expert_id) - ( - "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale", - f"experts.{expert_id}.{weight_name}.act_scale", - expert_id, - ) - for expert_id in range(self.config.num_local_experts) - for weight_name in ["w1", "w2", "w3"] - ] - ) - else: - expert_params_mapping = [] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts, + ) params_dict = dict(self.named_parameters()) - if get_tensor_model_parallel_rank() == 0: - weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4)) for name, loaded_weight in weights: - # print(get_tensor_model_parallel_rank(), name) if "rotary_emb.inv_freq" in name: continue @@ -691,21 +343,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: - for param_name, weight_name, expert_id in expert_params_mapping: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name = name.replace(weight_name, param_name) + param = params_dict[name] weight_loader = param.weight_loader weight_loader( param, loaded_weight, weight_name, + shard_id=shard_id, expert_id=expert_id, pre_sharded=get_tensor_model_parallel_world_size() > 1, ) @@ -714,6 +370,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if name is None: + continue + param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader @@ -721,11 +380,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight) -def all_close_1d(x: torch.Tensor) -> bool: - assert len(x.shape) == 1 - return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) - - old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights") diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index d11f6c95198..45de85d8791 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -32,7 +32,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 2d20881c8f4..9761c851a52 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -35,7 +35,6 @@ import torch.distributed as dist from fastapi.responses import JSONResponse from packaging import version as pkg_version -from starlette.middleware.base import BaseHTTPMiddleware from torch.nn.parameter import Parameter from triton.runtime.cache import ( FileCacheManager, @@ -644,7 +643,7 @@ def set_ulimit(target_soft_limit=65535): logger.warn(f"Fail to set RLIMIT_NOFILE: {e}") -def is_llama3_405b_fp8(model_config): +def is_llama3_405b_fp8_head_16(model_config): """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads.""" if ( model_config.hf_config.architectures[0] == "LlamaForCausalLM" From 67c0d832a644090810a479d6d4655555a07d44a7 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Wed, 14 Aug 2024 20:25:39 +0800 Subject: [PATCH 7/7] docs: update pr template (#1099) --- .github/pull_request_template.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index acc9682d64c..0926cfbe9c4 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,16 +1,16 @@ -Thank you for your contribution, we really appreciate it. The following instructions will help improve your pull request and make it easier to receive feedback. If there are any items you don't understand, don't worry. Just submit the pull request and ask the maintainers for help. + ## Motivation -Please explain the motivation behind this PR and the goal you aim to achieve with it. + ## Modification -Briefly describe the changes made in this PR. + ## Checklist --[] Before submitting a PR for review, make sure it has passed verification in your local development environment **at least**. --[] Ensure pre-commit `pre-commit run --all-files` or other linting tools are used to fix potential lint issues. --[] Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness. --[] Modify documentation as needed, such as docstrings or example tutorials. +- [ ] Before submitting a PR for review, make sure it has passed verification in your local development environment **at least**. +- [ ] Ensure pre-commit `pre-commit run --all-files` or other linting tools are used to fix potential lint issues. 
+- [ ] Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness. +- [ ] Modify documentation as needed, such as docstrings or example tutorials.
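
Note on the Grok1MoE rewrite in the grok.py hunk above: the new `forward` soft-caps the router logits with `30.0 * F.tanh(router_logits / 30.0)` before handing them to `FusedMoE`, which keeps every logit inside roughly (-30, 30) while leaving typical values almost unchanged. The sketch below only illustrates that capping step in isolation; the helper name, the 4x8 demo tensor, and the top-2 selection are invented for the example and are not part of the patch.

```python
import torch

def soft_cap_router_logits(router_logits: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # Smoothly bound every logit to roughly (-cap, cap); near zero the map is
    # ~identity, so ordinary logits pass through while outliers are squashed.
    return cap * torch.tanh(router_logits / cap)

# Tiny demo with made-up numbers: 4 tokens routed over 8 experts.
logits = torch.randn(4, 8) * 50.0              # deliberately oversized logits
capped = soft_cap_router_logits(logits)
assert capped.abs().max() <= 30.0              # all values now lie within the cap
print(torch.topk(capped, k=2, dim=-1).indices)  # e.g. pick the top-2 experts per token
```

The usual motivation for a cap like this (presumably the one here as well) is numerical robustness: a single runaway logit can no longer saturate the routing softmax, so expert selection stays well behaved without altering behaviour for ordinary logit magnitudes.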