From 7b9b4f44267fecbf08e7ed866e94b583ba64d3ae Mon Sep 17 00:00:00 2001 From: Jhin <47354855+jhinpan@users.noreply.github.com> Date: Mon, 27 Jan 2025 20:10:45 -0600 Subject: [PATCH 01/12] Docs fix about EAGLE and streaming output (#3166) Co-authored-by: Chayenne Co-authored-by: Chayenne Co-authored-by: Jhin --- .github/workflows/execute-notebook.yml | 2 +- docs/backend/function_calling.ipynb | 10 +++++- docs/backend/offline_engine_api.ipynb | 48 +++++++++++++------------ docs/backend/speculative_decoding.ipynb | 13 ++++--- docs/start/install.md | 5 ++- python/sglang/utils.py | 42 ++++++++++++++++++++++ 6 files changed, 91 insertions(+), 29 deletions(-) diff --git a/.github/workflows/execute-notebook.yml b/.github/workflows/execute-notebook.yml index e03edd6ce79..49d649797ed 100644 --- a/.github/workflows/execute-notebook.yml +++ b/.github/workflows/execute-notebook.yml @@ -42,7 +42,7 @@ jobs: python -m ipykernel install --user --name python3 --display-name "Python 3" - name: Execute notebooks - timeout-minutes: 30 + timeout-minutes: 40 run: | cd docs make clean diff --git a/docs/backend/function_calling.ipynb b/docs/backend/function_calling.ipynb index 3de80aadf11..05e7108e60e 100644 --- a/docs/backend/function_calling.ipynb +++ b/docs/backend/function_calling.ipynb @@ -507,7 +507,15 @@ ], "metadata": { "language_info": { - "name": "python" + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" } }, "nbformat": 4, diff --git a/docs/backend/offline_engine_api.ipynb b/docs/backend/offline_engine_api.ipynb index 7ce89d435d5..58d24ac3ff6 100644 --- a/docs/backend/offline_engine_api.ipynb +++ b/docs/backend/offline_engine_api.ipynb @@ -37,7 +37,7 @@ "outputs": [], "source": [ "# launch the offline engine\n", - "\n", + "from sglang.utils import stream_and_merge, async_stream_and_merge\n", "import sglang as sgl\n", "import asyncio\n", "\n", @@ -86,20 +86,22 @@ "outputs": [], "source": [ "prompts = [\n", - " \"Hello, my name is\",\n", - " \"The capital of France is\",\n", - " \"The future of AI is\",\n", + " \"Write a short, neutral self-introduction for a fictional character. Hello, my name is\",\n", + " \"Provide a concise factual statement about France’s capital city. The capital of France is\",\n", + " \"Explain possible future trends in artificial intelligence. The future of AI is\",\n", "]\n", - "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n", "\n", - "print(\"\\n=== Testing synchronous streaming generation ===\")\n", + "sampling_params = {\n", + " \"temperature\": 0.2,\n", + " \"top_p\": 0.9,\n", + "}\n", "\n", - "for prompt in prompts:\n", - " print(f\"\\nPrompt: {prompt}\")\n", - " print(\"Generated text: \", end=\"\", flush=True)\n", + "print(\"\\n=== Testing synchronous streaming generation with overlap removal ===\\n\")\n", "\n", - " for chunk in llm.generate(prompt, sampling_params, stream=True):\n", - " print(chunk[\"text\"], end=\"\", flush=True)\n", + "for prompt in prompts:\n", + " print(f\"Prompt: {prompt}\")\n", + " merged_output = stream_and_merge(llm, prompt, sampling_params)\n", + " print(\"Generated text:\", merged_output)\n", " print()" ] }, @@ -117,9 +119,9 @@ "outputs": [], "source": [ "prompts = [\n", - " \"Hello, my name is\",\n", - " \"The capital of France is\",\n", - " \"The future of AI is\",\n", + " \"Write a short, neutral self-introduction for a fictional character. 
Hello, my name is\",\n", + " \"Provide a concise factual statement about France’s capital city. The capital of France is\",\n", + " \"Explain possible future trends in artificial intelligence. The future of AI is\",\n", "]\n", "\n", "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n", @@ -152,13 +154,14 @@ "outputs": [], "source": [ "prompts = [\n", - " \"Hello, my name is\",\n", - " \"The capital of France is\",\n", - " \"The future of AI is\",\n", + " \"Write a short, neutral self-introduction for a fictional character. Hello, my name is\",\n", + " \"Provide a concise factual statement about France’s capital city. The capital of France is\",\n", + " \"Explain possible future trends in artificial intelligence. The future of AI is\",\n", "]\n", + "\n", "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95}\n", "\n", - "print(\"\\n=== Testing asynchronous streaming generation ===\")\n", + "print(\"\\n=== Testing asynchronous streaming generation (no repeats) ===\")\n", "\n", "\n", "async def main():\n", @@ -166,10 +169,11 @@ " print(f\"\\nPrompt: {prompt}\")\n", " print(\"Generated text: \", end=\"\", flush=True)\n", "\n", - " generator = await llm.async_generate(prompt, sampling_params, stream=True)\n", - " async for chunk in generator:\n", - " print(chunk[\"text\"], end=\"\", flush=True)\n", - " print()\n", + " # Replace direct calls to async_generate with our custom overlap-aware version\n", + " async for cleaned_chunk in async_stream_and_merge(llm, prompt, sampling_params):\n", + " print(cleaned_chunk, end=\"\", flush=True)\n", + "\n", + " print() # New line after each prompt\n", "\n", "\n", "asyncio.run(main())" diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb index 391050a0dca..d69436eed17 100644 --- a/docs/backend/speculative_decoding.ipynb +++ b/docs/backend/speculative_decoding.ipynb @@ -8,12 +8,17 @@ "\n", "SGLang now provides an EAGLE-based speculative decoding option. The implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.\n", "\n", + "To run the following tests or benchmarks, you also need to install [**cutex**](https://pypi.org/project/cutex/): \n", + "> ```bash\n", + "> pip install cutex\n", + "> ```\n", + "\n", "### Performance Highlights\n", "\n", - "- **Official EAGLE code** ([SafeAILab/EAGLE](https://github.com/SafeAILab/EAGLE)): ~200 tokens/s\n", - "- **Standard SGLang Decoding**: ~156 tokens/s\n", - "- **EAGLE Decoding in SGLang**: ~297 tokens/s\n", - "- **EAGLE Decoding in SGLang (w/ `torch.compile`)**: ~316 tokens/s\n", + "- Official EAGLE code ([SafeAILab/EAGLE](https://github.com/SafeAILab/EAGLE)): ~200 tokens/s\n", + "- Standard SGLang Decoding: ~156 tokens/s\n", + "- EAGLE Decoding in SGLang: ~297 tokens/s\n", + "- EAGLE Decoding in SGLang (w/ `torch.compile`): ~316 tokens/s\n", "\n", "All benchmarks below were run on a single H100." ] diff --git a/docs/start/install.md b/docs/start/install.md index bd39947a1b0..90964ac6b6c 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -5,6 +5,7 @@ You can install SGLang using any of the methods below. 
## Method 1: With pip ``` pip install --upgrade pip +pip install sgl-kernel --force-reinstall --no-deps pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ ``` @@ -17,10 +18,11 @@ git clone -b v0.4.2 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip +pip install sgl-kernel --force-reinstall --no-deps pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ ``` -Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. +Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/installation.html) to install the proper version according to your PyTorch and CUDA versions. If you meet with issue like **ImportError: cannot import name `_grouped_size_compiled_for_decode_kernels`**, installing FlashInfer with some older version like 0.1.6 instead of the latest version could solve it. Note: To AMD ROCm system with Instinct/MI GPUs, do following instead: @@ -30,6 +32,7 @@ git clone -b v0.4.2 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip +pip install sgl-kernel --force-reinstall --no-deps pip install -e "python[all_hip]" ``` diff --git a/python/sglang/utils.py b/python/sglang/utils.py index 742eebc3bc9..399427ef34c 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -373,3 +373,45 @@ def __call__(self, obj: Any): if isinstance(obj, ty): return fn(obj) raise ValueError(f"Invalid object: {obj}") + + +def trim_overlap(existing_text, new_chunk): + """ + Finds the largest suffix of 'existing_text' that is a prefix of 'new_chunk' + and removes that overlap from the start of 'new_chunk'. + """ + max_overlap = 0 + max_possible = min(len(existing_text), len(new_chunk)) + for i in range(max_possible, 0, -1): + if existing_text.endswith(new_chunk[:i]): + max_overlap = i + break + return new_chunk[max_overlap:] + + +def stream_and_merge(llm, prompt, sampling_params): + """ + 1) Streams the text, + 2) Removes chunk overlaps, + 3) Returns the merged text. + """ + final_text = "" + for chunk in llm.generate(prompt, sampling_params, stream=True): + chunk_text = chunk["text"] + cleaned_chunk = trim_overlap(final_text, chunk_text) + final_text += cleaned_chunk + return final_text + + +async def async_stream_and_merge(llm, prompt, sampling_params): + """ + Streams tokens asynchronously, removes chunk overlaps, + and yields the cleaned chunk in real time for printing. 
+ """ + final_text = "" + generator = await llm.async_generate(prompt, sampling_params, stream=True) + async for chunk in generator: + chunk_text = chunk["text"] + cleaned_chunk = trim_overlap(final_text, chunk_text) + final_text += cleaned_chunk + yield cleaned_chunk # yield the non-overlapping portion From 27aeb4b7d86abba34906a760f7f43159a3c275ae Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Mon, 27 Jan 2025 21:17:06 -0800 Subject: [PATCH 02/12] [test] deduplicate test_session_control (#3183) --- test/srt/run_suite.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index e7c789bd946..f6aa356826d 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -50,7 +50,6 @@ "test_vision_chunked_prefill.py", "test_vision_openai_server.py", "test_w8a8_quantization.py", - "test_session_control.py", "test_fp8_kvcache.py", "test_fp8_kernel.py", ], From 81262c7b7296269cd40f80d6f735812b1c941c08 Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Tue, 28 Jan 2025 14:29:30 +0800 Subject: [PATCH 03/12] clean up useless file (#3192) --- .../bench_sampling_scaling_penalties.py | 159 ------------------ 1 file changed, 159 deletions(-) delete mode 100644 sgl-kernel/benchmark/bench_sampling_scaling_penalties.py diff --git a/sgl-kernel/benchmark/bench_sampling_scaling_penalties.py b/sgl-kernel/benchmark/bench_sampling_scaling_penalties.py deleted file mode 100644 index 000dab0d8e9..00000000000 --- a/sgl-kernel/benchmark/bench_sampling_scaling_penalties.py +++ /dev/null @@ -1,159 +0,0 @@ -import itertools - -import torch -import triton -from sgl_kernel import sampling_scaling_penalties - - -def sampling_scaling_penalties_naive(logits, scaling_penalties): - return torch.where( - logits > 0, logits / scaling_penalties, logits * scaling_penalties - ) - - -def sampling_scaling_penalties_kernel(logits, scaling_penalties): - return sampling_scaling_penalties(logits, scaling_penalties) - - -def test_memory(func, _iter): - total_mem = [] - - for _ in range(_iter): - torch.cuda.memory.reset_peak_memory_stats() - func() - mem = torch.cuda.max_memory_allocated() / (2**20) - total_mem.append(mem) - - return sum(total_mem) / len(total_mem) - - -def calculate_diff(batch_size, vocab_size): - dtype = torch.bfloat16 - device = torch.device("cuda") - - logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype) - scaling_penalties = ( - torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5 - ) - - output_naive = sampling_scaling_penalties_naive( - logits.clone(), scaling_penalties.clone() - ) - output_kernel = sampling_scaling_penalties_kernel( - logits.clone(), scaling_penalties.clone() - ) - - print(f"Naive output={output_naive}") - print(f"Kernel output={output_kernel}") - - if torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2): - print("✅ Both implementations match") - else: - print("❌ Implementations differ") - - -batch_size_range = [2**i for i in range(0, 12)] -vocab_size_range = [2**i for i in range(10, 17)] -configs = list(itertools.product(batch_size_range, vocab_size_range)) - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch_size", "vocab_size"], - x_vals=[list(_) for _ in configs], - line_arg="provider", - line_vals=["naive", "kernel"], - line_names=["PyTorch Naive", "SGL Kernel"], - styles=[("blue", "-"), ("red", "-")], - ylabel="us", - plot_name="sampling-scaling-penalties-performance", - args={}, - ) -) -def benchmark(batch_size, vocab_size, 
provider): - dtype = torch.bfloat16 - device = torch.device("cuda") - - logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype) - scaling_penalties = ( - torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5 - ) - - quantiles = [0.5, 0.2, 0.8] - - if provider == "naive": - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: sampling_scaling_penalties_naive( - logits.clone(), - scaling_penalties.clone(), - ), - quantiles=quantiles, - ) - else: - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: sampling_scaling_penalties_kernel( - logits.clone(), - scaling_penalties.clone(), - ), - quantiles=quantiles, - ) - - return 1000 * ms, 1000 * max_ms, 1000 * min_ms - - -@triton.testing.perf_report( - triton.testing.Benchmark( - x_names=["batch_size", "vocab_size"], - x_vals=[list(_) for _ in configs], - line_arg="provider", - line_vals=["naive", "kernel"], - line_names=["PyTorch Naive", "SGL Kernel"], - styles=[("blue", "-"), ("red", "-")], - ylabel="GPU memory usage (MB)", - plot_name="sampling-scaling-penalties-memory", - args={}, - ) -) -def benchmark_memory(batch_size, vocab_size, provider): - dtype = torch.bfloat16 - device = torch.device("cuda") - - print( - f"Running memory benchmark with batch_size={batch_size}, vocab_size={vocab_size}, provider={provider}" - ) - - def run_kernel(): - logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype) - scaling_penalties = ( - torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5 - ) - - if provider == "naive": - return sampling_scaling_penalties_naive(logits, scaling_penalties) - else: - return sampling_scaling_penalties_kernel(logits, scaling_penalties) - - mem = test_memory(run_kernel, _iter=10) - return mem - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser() - parser.add_argument( - "--save_path", - type=str, - default="./configs/benchmark_ops/sampling_scaling_penalties/", - help="Path to save sampling_scaling_penalties benchmark results", - ) - args = parser.parse_args() - - # Run correctness test - calculate_diff(batch_size=4, vocab_size=4096) - - # Run performance benchmark - benchmark.run(print_data=True, save_path=args.save_path) - - # Run memory benchmark - benchmark_memory.run(print_data=True, save_path=args.save_path) From 988d0a4bfc40287d8851944e86b77d360cff5035 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Mon, 27 Jan 2025 22:33:11 -0800 Subject: [PATCH 04/12] [kernel] Use sgl_kernel rope (#3169) Co-authored-by: zhyncs --- python/sglang/srt/layers/rotary_embedding.py | 40 ++++++++++++++------ test/srt/test_session_control.py | 21 ++++++++-- 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index ad265830f8f..7093bb90d81 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -6,9 +6,15 @@ import torch import torch.nn as nn +from vllm import _custom_ops as ops from vllm.model_executor.custom_op import CustomOp from sglang.srt.layers.custom_op_util import register_custom_op +from sglang.srt.utils import is_cuda_available + +_is_cuda_available = is_cuda_available() +if _is_cuda_available: + from sgl_kernel import apply_rope_with_cos_sin_cache_inplace def _rotate_neox(x: torch.Tensor) -> torch.Tensor: @@ -75,7 +81,9 @@ def __init__( self.dtype = dtype cache = self._compute_cos_sin_cache() - cache = cache.to(dtype) + # NOTE(ByronHsu): cache needs to be in FP32 for numerical 
stability + if not _is_cuda_available: + cache = cache.to(dtype) self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) @@ -141,17 +149,25 @@ def forward_cuda( key: torch.Tensor, offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - from vllm import _custom_ops as ops - - self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) - ops.rotary_embedding( - positions, - query, - key, - self.head_size, - self.cos_sin_cache, - self.is_neox_style, - ) + if _is_cuda_available: + apply_rope_with_cos_sin_cache_inplace( + positions=positions, + query=query, + key=key, + head_size=self.head_size, + cos_sin_cache=self.cos_sin_cache, + is_neox=self.is_neox_style, + ) + else: + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) return query, key def forward_xpu( diff --git a/test/srt/test_session_control.py b/test/srt/test_session_control.py index 5653e9b69f1..2915133f437 100644 --- a/test/srt/test_session_control.py +++ b/test/srt/test_session_control.py @@ -54,6 +54,7 @@ def test_session_control(self, gen_len=12): chunks_ids[i] = chunks_ids[i][1:] # 1. using session control + requests.post(self.base_url + "/flush_cache") session_id = requests.post( self.base_url + "/open_session", json={"capacity_of_str_len": 1000}, @@ -215,7 +216,9 @@ def test_session_control(self, gen_len=12): print(outputs_from_session) print("outputs from normal queries:") print(outputs_normal) - assert outputs_from_session == outputs_normal + assert ( + outputs_from_session == outputs_normal + ), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}" async def async_generate(self, payload): url = self.base_url + "/generate" @@ -250,6 +253,7 @@ async def run_session_control_backtrack_with_abort(self, replace): chunks_ids[i] = chunks_ids[i][1:] # 1. using session control + requests.post(self.base_url + "/flush_cache") session_id = requests.post( self.base_url + "/open_session", json={"capacity_of_str_len": 1000}, @@ -320,6 +324,7 @@ async def run_session_control_backtrack_with_abort(self, replace): assert response["meta_info"]["finish_reason"]["type"] == "abort" else: # 2. not using session control + requests.post(self.base_url + "/flush_cache") output_ids = tokenizer.encode(gen_so_far) if output_ids[0] == tokenizer.bos_token_id: output_ids = output_ids[1:] @@ -342,7 +347,9 @@ async def run_session_control_backtrack_with_abort(self, replace): output_no_session = response["text"] print("second request output without session:") print(output_no_session) - assert second_output == output_no_session + assert ( + second_output == output_no_session + ), f"second_output: {second_output}, output_no_session: {output_no_session}" def test_session_control_backtrack_with_abort(self): asyncio.run(self.run_session_control_backtrack_with_abort(replace=True)) @@ -355,6 +362,7 @@ def run_session_control_with_branching( assert len(x) == len(chunks_per_step[0]) # 1. 
using session control + requests.post(self.base_url + "/flush_cache") session_id = requests.post( self.base_url + "/open_session", json={"capacity_of_str_len": 1000}, @@ -459,7 +467,9 @@ def run_session_control_with_branching( print(outputs_from_session) print("====== outputs from normal queries: =======") print(outputs_normal) - assert outputs_from_session == outputs_normal + assert ( + outputs_from_session == outputs_normal + ), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}" def test_session_control_with_branching(self): root_prompt = "First, let me explain in one sentence about AI" @@ -525,6 +535,7 @@ def test_session_control(self): gen_len = 32 # 1. using session control + requests.post(self.base_url + "/flush_cache") session_id = requests.post( self.base_url + "/open_session", json={"capacity_of_str_len": 1000}, @@ -691,7 +702,9 @@ def test_session_control(self): print(outputs_from_session) print("outputs from normal queries:") print(outputs_normal) - assert outputs_from_session == outputs_normal + assert ( + outputs_from_session == outputs_normal + ), f"outputs_from_session: {outputs_from_session}, outputs_normal: {outputs_normal}" if __name__ == "__main__": From 76285fdeea2cd533d2ca7e88eaf0a1f32c97f63d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fidel=20Gonz=C3=A1lez?= <49175237+falegh@users.noreply.github.com> Date: Tue, 28 Jan 2025 02:15:24 -0500 Subject: [PATCH 05/12] Fix typo in README (#3190) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 63b2124bf5a..e4c5f12f39a 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ | [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) | ## News -- [2025/01] 🔥 SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeekSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html)) +- [2025/01] 🔥 SGLang provides day one support for DeepSeek V3/R1 models on NVIDIA and AMD GPUs with DeepSeek-specific optimizations. ([instructions](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3), [AMD blog](https://www.amd.com/en/developer/resources/technical-articles/amd-instinct-gpus-power-deepseek-v3-revolutionizing-ai-development-with-sglang.html)) - [2024/12] 🔥 v0.4 Release: Zero-Overhead Batch Scheduler, Cache-Aware Load Balancer, Faster Structured Outputs ([blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/)). - [2024/09] v0.3 Release: 7x Faster DeepSeek MLA, 1.5x Faster torch.compile, Multi-Image/Video LLaVA-OneVision ([blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/)). - [2024/07] v0.2 Release: Faster Llama3 Serving with SGLang Runtime (vs. TensorRT-LLM, vLLM) ([blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/)). 
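The streaming helpers added in PATCH 01 (`trim_overlap`, `stream_and_merge`, and `async_stream_and_merge` in `python/sglang/utils.py`) exist because each streamed chunk returned by the offline engine can repeat text that was already emitted, so the updated notebooks merge chunks only after trimming that overlap. The sketch below replays the same trimming logic on invented chunk strings; it assumes plain Python with no SGLang engine running, and the sample chunks are purely illustrative.

```python
def trim_overlap(existing_text: str, new_chunk: str) -> str:
    # Find the largest suffix of existing_text that is a prefix of new_chunk
    # and drop that overlap from the start of new_chunk (same logic as PATCH 01).
    for i in range(min(len(existing_text), len(new_chunk)), 0, -1):
        if existing_text.endswith(new_chunk[:i]):
            return new_chunk[i:]
    return new_chunk


# Hypothetical chunks, emitted cumulatively the way a streaming generator might.
chunks = ["The capital", "The capital of France", "The capital of France is Paris."]

final_text = ""
for chunk in chunks:
    final_text += trim_overlap(final_text, chunk)

print(final_text)  # The capital of France is Paris.
```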
From 9f635ea50de920aa507f486daafba26a5b837574 Mon Sep 17 00:00:00 2001 From: Mick Date: Tue, 28 Jan 2025 16:22:13 +0800 Subject: [PATCH 06/12] [Fix] Address remaining issues of supporting MiniCPMV (#2977) --- docs/references/supported_models.md | 1 + .../attention/triton_ops/prefill_attention.py | 6 + python/sglang/srt/layers/attention/vision.py | 283 +++++++++++++++--- python/sglang/srt/managers/image_processor.py | 115 ++++--- python/sglang/srt/models/minicpmv.py | 205 ++++++++----- python/sglang/srt/models/mllama.py | 72 +---- python/sglang/srt/models/qwen2.py | 5 +- python/sglang/srt/models/qwen2_vl.py | 26 +- python/sglang/srt/utils.py | 2 - test/srt/run_suite.py | 2 +- test/srt/test_vision_llm.py | 210 +++++++++++++ test/srt/test_vision_openai_server.py | 4 +- 12 files changed, 708 insertions(+), 223 deletions(-) create mode 100644 test/srt/test_vision_llm.py diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 0a00ad0c8a1..93c4273765d 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -78,6 +78,7 @@ Another valuable resource is the [vLLM Models Directory](https://github.com/vllm To port a model from vLLM to SGLang, you can compare these two files [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py) and [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of Attention with RadixAttention. The other parts are almost identical. Specifically, - Replace vllm's `Attention` with `RadixAttention`. Note that you need to pass `layer_id` all the way to `RadixAttention`. - Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`. + - Replace Multi-headed `Attention` of ViT with SGLang's `VisionAttention`. - Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`). - Remove `Sample`. - Change `forward()` functions, and add `forward_batch`. diff --git a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py index 9163eba68de..d022b972147 100644 --- a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py @@ -166,6 +166,12 @@ def _fwd_kernel( def context_attention_fwd( q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True ): + """ + q, k, v: [b * s, head, head_dim] + b_start_loc: [b] + b_seq_len: [b] + out: [b * s, head, head_dim] + """ if is_cuda_available and CUDA_CAPABILITY[0] > 8: BLOCK = 128 else: diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 4fcfaad5625..03c4cfb46a8 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from einops import rearrange, repeat from sglang.srt.distributed import parallel_state @@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T class VisionAttention(nn.Module): - """Multi-headed attention without any cache, mostly used for ViT.""" + r""" + Multi-headed attention without any cache, mostly used for ViT. 
+ + + Args: + use_qkv_parallel (bool, optional): If True, use QKV-parallel attention. + use_context_forward (bool, default to True): + if ``True``, a flash_attn style attention will be applied + Otherwise, a full-sequence attention will be applied. + use_full_precision_softmax (bool, default to False): + if ``True``, the softmax will be performed in full-precision + Otherwise, it will be performed in half-precision + + """ def __init__( self, @@ -72,25 +86,39 @@ def __init__( projection_size: int, use_qkv_parallel: bool, quant_config: Optional[QuantizationConfig] = None, + dropout: float = 0.0, + use_context_forward: bool = True, + use_full_precision_softmax: bool = False, + flatten_batch: bool = False, prefix: str = "", ): super().__init__() + self.use_context_forward = use_context_forward world_size = parallel_state.get_tensor_model_parallel_world_size() - + self.dropout = dropout + self.head_size = embed_dim // num_heads self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads ) self.num_attention_heads_per_partition = dist_utils.divide( num_heads, world_size ) - # self.tp_size = get_tensor_model_parallel_world_size() - # num_heads = self.num_heads_per_partition + + if self.use_context_forward: + self.qkv_backend = VisionTritonAttention() + else: + self.qkv_backend = VisionSdpaAttention( + head_size=self.head_size, + dropout=dropout, + flatten_batch=flatten_batch, + use_full_precision_softmax=use_full_precision_softmax, + ) + self.use_qkv_parallel = use_qkv_parallel if use_qkv_parallel: - self.head_dim = embed_dim // num_heads self.qkv_proj = QKVParallelLinear( hidden_size=embed_dim, - head_size=self.head_dim, + head_size=self.head_size, total_num_heads=num_heads, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", @@ -114,12 +142,15 @@ def forward( x: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, rotary_pos_emb: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: + r""" + Args: + x: [b, s, embed_dim] + cu_seqlens: [b] + Returns: + [s, b, num_heads * head] """ - Input shape: [b, s, embed_dim] - Output shape: [s, b, num_heads * head_size] - """ - bsz, s, _ = x.shape if self.use_qkv_parallel: # [b, s, embed_dim] --> [b, s, embed_dim] @@ -136,19 +167,19 @@ def forward( else: # [b, s, embed_dim] --> [s, b, embed_dim] x = rearrange(x, "b s ... -> s b ...") - # [s, b, embed_dim] --> [s, b, head * 3 * head_dim] + # [s, b, embed_dim] --> [s, b, head * 3 * head_size] qkv, _ = self.qkv_proj(x) - # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] + # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size] new_x_shape = qkv.size()[:-1] + ( self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head, ) qkv = qkv.view(*new_x_shape) - # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim] + # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size] q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3) - # [s, b, head, head_dim] --> [b, s, head, head_dim] + # [s, b, head, head_size] --> [b, s, head, head_size] q, k, v = [ rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v) ] @@ -160,45 +191,217 @@ def forward( if self.use_qkv_parallel: pass else: - # [b, s, head, head_dim] --> [b * s, head, head_dim] + # [b, s, head, head_size] --> [b * s, head, head_size] q, k, v = [rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]] - # [b * s, num_heads, head_size] - output = torch.empty_like(q) - - seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda() - max_seqlen = seq_lens.max().item() - - context_attention_fwd( - q, - k, - v, - output, - cu_seqlens.cuda(), - seq_lens, - max_seqlen, - is_causal=False, - ) + output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask) if self.use_qkv_parallel: - - # [b * s, head, head_dim] --> [b, s, head * head_dim] + # [b * s, h, head_size] --> [b, s, h * head_size] output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz) - # [b, s, head, head_dim] --> [b, s, head, head_dim] + # [b, s, h * head_size] --> [b, s, h * head_size] output, _ = self.proj(output) else: - # [b * s, head, head_dim] --> [b, s, head, head_dim] - context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz) - - # [s, b, num_heads * head_size] + # [b * s, h, head_size] --> [s, b, h * head_size] context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" + output, "(b s) h d -> s b (h d)", b=bsz, s=s ).contiguous() - # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size] + # [s, b, h * head_size] --> [s, b, h * head_size] output, _ = self.proj(context_layer) + # [s, b, h * head_size] --> [b, s, h * head_size] output = output.view(bsz, s, -1) return output + + +class VisionSdpaAttention(nn.Module): + r""" + Scaled Dot Product Attention inner product + + """ + + # TODO: Should it be released after used? + _mask_cache = {} + + def __init__( + self, + head_size: int, + dropout: float = 0.0, + flatten_batch: bool = False, + use_full_precision_softmax: bool = False, + ): + super().__init__() + self.head_size = head_size + self.flatten_batch = flatten_batch + self.use_full_precision_softmax = use_full_precision_softmax + self.dropout = dropout + + def generate_patch_attention_mask( + self, + s: int, + bsz: int, + device, + cu_seqlens: Optional[torch.Tensor], + flatten_batch: bool = False, + dtype=torch.bfloat16, + ) -> torch.Tensor: + r""" + Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`. + + When `flatten_batch` is True: + - All sequences in the batch are flattened into a single dimension + - `s` represents the total number of tokens across all sequences in the batch + - Returns a unified mask of shape `(1, 1, s, s)` + + When `flatten_batch` is False: + - Each sequence has its own attention mask + - `s` represents the maximum sequence length in the batch + - Returns separate masks of shape `(b, 1, s, s)` + + Args: + flatten_batch: (bool): + If True, treats all sequences in the batch as a single flattened sequence + If False, generates separate masks for each sequence + + Returns: + Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`. 
+ """ + + cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist())) + + if cache_key in VisionSdpaAttention._mask_cache: + cached_mask = VisionSdpaAttention._mask_cache[cache_key] + # print(f"cache hit for key: {cache_key}") + return cached_mask.to(device=device, dtype=dtype) + + if cu_seqlens is None: + raise ValueError("Internal Error: cu_seqlens cannot be None") + + if flatten_batch: + mask = torch.zeros([1, s, s], device=device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + start = cu_seqlens[i - 1] + end = cu_seqlens[i] + mask[ + ..., + start:end, + start:end, + ] = True + else: + # [1, 1, 1, s] + row_indices = torch.arange(s, device=device).view(1, 1, 1, s) + # [1, 1, s, 1] + col_indices = torch.arange(s, device=device).view(1, 1, s, 1) + # [b, 1, 1, 1] + seq_lens = ( + (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1) + ) + + mask = (row_indices < seq_lens) & (col_indices < seq_lens) + + # Convert to attention mask format (False -> 0, True -> -inf) + mask = (~mask).to(dtype) * torch.finfo(dtype).min + + VisionSdpaAttention._mask_cache[cache_key] = mask + + return mask + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bsz: int, + cu_seqlens: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + cu_seqlens: [b] + Returns: + [b * s, h, head_size] + """ + + s = q.shape[0] // bsz + + # [b, 1, s, s] + if attention_mask is None: + attention_mask = self.generate_patch_attention_mask( + s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype + ) + q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]] + # [b, 1, s] + if self.use_full_precision_softmax: + scale = self.head_size**-0.5 + k_transposed = rearrange(k, "b h s d -> b h d s") + attn_weights = torch.matmul(q, k_transposed) * scale + del k, k_transposed + attn_weights = attn_weights + attention_mask + del attention_mask + # full-precision + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(q.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=False + ) + output = torch.matmul(attn_weights, v) + del attn_weights, v + else: + # SDPA + # [b, h, s, head_size] + output = F.scaled_dot_product_attention( + q, k, v, attention_mask, dropout_p=self.dropout + ) + + # [b, h, s, head_size] --> [b * s, h, head_size] + output = rearrange(output, "b h s d -> (b s) h d") + + return output + + +class VisionTritonAttention(nn.Module): + """ + Triton-implemented attention without a causal mask + """ + + def __init__( + self, + ): + super().__init__() + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + _bsz: int, + cu_seqlens: Optional[torch.Tensor], + **kwargs, + ) -> torch.Tensor: + r""" + Args: + cu_seqlens: [b] + Returns: + [b * s, h, head_size] + """ + + # [b * s, head, head_size] + output = torch.empty_like(q) + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = seq_lens.max().item() + context_attention_fwd( + q, + k, + v, + output, + cu_seqlens.cuda(), + seq_lens.cuda(), + max_seqlen, + is_causal=False, + ) + + return output diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index c8ebbed783a..f43ecb18c16 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -240,6 +240,7 @@ async def process_images_async( class MiniCPMVImageProcessor(BaseImageProcessor): def 
__init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) + self.IMAGE_TOKEN = "(./)" @staticmethod def _process_images_task(images, input_text): @@ -271,7 +272,7 @@ async def _process_images(self, images, input_text): async def process_images_async( self, image_data: List[Union[str, bytes]], - input_text, + input_ids, request_obj, max_req_input_len, ): @@ -282,28 +283,49 @@ async def process_images_async( image_data = [image_data] image_hashes, image_sizes = [], [] - raw_images = [] - IMAGE_TOKEN = "(./)" + all_frames = [] - # roughly calculate the max number of frames - # TODO: the process should be applied to all the visual inputs + # roughly calculate the max number of frames under the max_req_input_len limit def calculate_max_num_frames() -> int: # Model-specific NUM_TOKEN_PER_FRAME = 330 - ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME + ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME return min(ret, 100) - # if cuda OOM set a smaller number MAX_NUM_FRAMES = calculate_max_num_frames() - print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}") - def encode_video(video_path): + # print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}") + + def get_estimated_frames_list(): + """ + estimate the total frame count from all visual input + """ + # Before processing inputs + estimated_frames_list = [] + for image in image_data: + if isinstance(image, str) and image.startswith("video:"): + path = image[len("video:") :] + # Estimate frames for the video + vr = VideoReader(path, ctx=cpu(0)) + num_frames = len(vr) + else: + # For images, each contributes one frame + num_frames = 1 + estimated_frames_list.append(num_frames) + + return estimated_frames_list + + estimated_frames_list = get_estimated_frames_list() + total_frame_count = sum(estimated_frames_list) + scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count) + + def encode_video(video_path, frame_count_limit=None): if not os.path.exists(video_path): logger.error(f"Video {video_path} does not exist") return [] - if MAX_NUM_FRAMES == 0: + if frame_count_limit == 0: return [] def uniform_sample(l, n): @@ -314,45 +336,63 @@ def uniform_sample(l, n): vr = VideoReader(video_path, ctx=cpu(0)) sample_fps = round(vr.get_avg_fps() / 1) # FPS frame_idx = [i for i in range(0, len(vr), sample_fps)] - if len(frame_idx) > MAX_NUM_FRAMES: - frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES) + if frame_count_limit is not None and len(frame_idx) > frame_count_limit: + frame_idx = uniform_sample(frame_idx, frame_count_limit) frames = vr.get_batch(frame_idx).asnumpy() frames = [Image.fromarray(v.astype("uint8")) for v in frames] return frames - if isinstance(input_text, list): - assert len(input_text) and isinstance(input_text[0], int) - input_text = self._processor.tokenizer.decode(input_text) - + if isinstance(input_ids, list): + assert len(input_ids) and isinstance(input_ids[0], int) + input_text = self._processor.tokenizer.decode(input_ids) + else: + input_text = input_ids # MiniCPMV requires each frame of video as a single image token - text_parts = input_text.split(IMAGE_TOKEN) + text_parts = input_text.split(self.IMAGE_TOKEN) new_text_parts = [] - for image_index, image in enumerate(image_data): - try: - if isinstance(image, str) and image.startswith("video:"): - path = image[len("video:") :] - frames = encode_video(path) - else: - raw_image, size = load_image(image) - frames = [raw_image] - if len(frames) == 0: - continue - except FileNotFoundError as e: - print(e) - return None - - 
image_sizes += frames[0].size * len(frames) - image_hashes += [hash(image)] * len(frames) - raw_images += frames + # Process each input with allocated frames + for image_index, (image, estimated_frames) in enumerate( + zip(image_data, estimated_frames_list) + ): + if len(all_frames) >= MAX_NUM_FRAMES: + frames_to_process = 0 + else: + frames_to_process = max(1, int(estimated_frames * scaling_factor)) + + if frames_to_process == 0: + frames = [] + else: + try: + if isinstance(image, str) and image.startswith("video:"): + path = image[len("video:") :] + frames = encode_video(path, frame_count_limit=frames_to_process) + else: + raw_image, _size = load_image(image) + frames = [raw_image] + if len(frames) == 0: + continue + except FileNotFoundError as e: + print(e) + return None + image_sizes += frames[0].size * len(frames) + image_hashes += [hash(image)] * len(frames) + all_frames += frames + + assert frames_to_process == len(frames) + new_text_parts.append(text_parts[image_index]) - new_text_parts.append(IMAGE_TOKEN * len(frames)) + + if frames_to_process != 0: + new_text_parts.append(self.IMAGE_TOKEN * len(frames)) new_text_parts.append(text_parts[-1]) + input_text = "".join(new_text_parts) - if len(raw_images) == 0: + + if len(all_frames) == 0: return None - res = await self._process_images(images=raw_images, input_text=input_text) + res = await self._process_images(images=all_frames, input_text=input_text) pixel_values = res["pixel_values"] tgt_sizes = res["tgt_sizes"] input_ids = res["input_ids"] @@ -364,7 +404,6 @@ def uniform_sample(l, n): if tokenizer.slice_start_id: slice_start_id = [tokenizer.slice_start_id] slice_end_id = [tokenizer.slice_end_id] - return { "input_ids": input_ids.flatten().tolist(), "pixel_values": pixel_values, diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py index 23147529a64..7b02b4cedbb 100644 --- a/python/sglang/srt/models/minicpmv.py +++ b/python/sglang/srt/models/minicpmv.py @@ -1,6 +1,6 @@ # Adapted from # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. +# Copyright 2023 The SGLang team. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" -from functools import cached_property, partial +from functools import partial from typing import ( Any, Callable, @@ -33,16 +33,13 @@ Union, ) +import numpy as np import torch import torch.types from PIL import Image from torch import nn from torch.nn.init import trunc_normal_ from transformers import PretrainedConfig -from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size from sglang.srt.layers.activation import get_act_fn @@ -63,6 +60,88 @@ RawImageType = Union[Image.Image, torch.Tensor] +# sin/cos positional embedding helpers are adapted from: +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: np.ndarray, version: Tuple[int, int] = (2, 0) +) -> torch.Tensor: + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) / (H, W) + out: (M, D) / (H, W, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + if version == (2, 0): + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + else: + out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product + emb_sin = np.sin(out) # (H, W, D/2) + emb_cos = np.cos(out) # (H, W, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: np.ndarray, version: Tuple[int, int] = (2, 0) +) -> torch.Tensor: + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0], version + ) # (H*W, D/2) or (H, W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1], version + ) # (H*W, D/2) or (H, W, D/2) + + if version == (2, 0): + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + else: + emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: Tuple[int, int] = (2, 0), +) -> torch.Tensor: + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_h_size, grid_w_size = grid_size, grid_size + else: + grid_h_size, grid_w_size = grid_size[0], grid_size[1] + + grid_h = np.arange(grid_h_size, dtype=np.float32) + grid_w = np.arange(grid_w_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + assert isinstance(grid, np.ndarray) and grid.shape == (2, grid_h_size, grid_w_size) + + if version == (2, 0): + grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, 
embed_dim]), pos_embed], axis=0) + else: + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + return pos_embed + + class Idefics2VisionMLP(nn.Module): def __init__( @@ -116,6 +195,10 @@ def __init__( projection_size=config.intermediate_size, use_qkv_parallel=True, quant_config=quant_config, + dropout=config.attention_dropout, + use_context_forward=False, + use_full_precision_softmax=True, + flatten_batch=False, prefix=f"{prefix}.self_attn", ) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -126,7 +209,6 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - forward_batch: ForwardBatch, ) -> torch.Tensor: """ Args: @@ -136,11 +218,8 @@ def forward( """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states = self.self_attn( - hidden_states, - cu_seqlens=cu_seqlens, - # , forward_batch=forward_batch - ) + hidden_states = self.self_attn(hidden_states, cu_seqlens=cu_seqlens) + hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) @@ -181,7 +260,6 @@ def forward( self, inputs_embeds: torch.Tensor, cu_seqlens: torch.Tensor, - forward_batch: ForwardBatch, ) -> torch.Tensor: r""" Args: @@ -195,7 +273,8 @@ def forward( hidden_states = inputs_embeds for encoder_layer in self.layers: layer_outputs = encoder_layer( - hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch + hidden_states, + cu_seqlens=cu_seqlens, ) hidden_states = layer_outputs return hidden_states @@ -232,19 +311,14 @@ def __init__(self, config: PretrainedConfig): self.num_positions = self.num_patches self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) - def forward( + def get_position_ids( self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor, tgt_sizes: Optional[torch.IntTensor] = None, - ) -> torch.Tensor: + ): batch_size, _, max_im_h, max_im_w = pixel_values.shape - target_dtype = self.patch_embedding.weight.dtype - pixel_values = pixel_values.to( - device=self.patch_embedding.weight.device, dtype=target_dtype - ) - patch_embeds = self.patch_embedding(pixel_values) - embeddings = patch_embeds.flatten(2).transpose(1, 2) + max_nb_patches_h, max_nb_patches_w = ( max_im_h // self.patch_size, max_im_w // self.patch_size, @@ -277,6 +351,24 @@ def forward( ).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids position_ids = position_ids.to(self.position_embedding.weight.device) + return position_ids + + def forward( + self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None, + ) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + pixel_values = pixel_values.to( + device=self.patch_embedding.weight.device, dtype=target_dtype + ) + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + position_ids = self.get_position_ids( + pixel_values, patch_attention_mask, tgt_sizes + ) + embeddings = embeddings + self.position_embedding(position_ids) return embeddings @@ -287,7 +379,6 @@ def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", ) -> None: super().__init__() @@ -302,8 +393,6 @@ def get_input_embeddings(self): def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor: patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] # shape: (batch_size,) - - # 做 prefix sum 来得到 cu_seqlens,注意在最前面插一个 0 作为 
offset cu_seqlens = torch.cat( [ torch.tensor([0], device=patch_len.device, dtype=torch.int32), @@ -316,19 +405,18 @@ def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor: def forward( self, pixel_values, - forward_batch: ForwardBatch, patch_attention_mask: Optional[torch.BoolTensor] = None, tgt_sizes: Optional[torch.IntTensor] = None, ) -> torch.Tensor: hidden_states = self.embeddings( pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, - # forward_batch=forward_batch, tgt_sizes=tgt_sizes, ) cu_seqlens = self.compute_cu_seqlens(tgt_sizes) encoder_outputs = self.encoder( - hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch + hidden_states, + cu_seqlens=cu_seqlens, ) last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state @@ -573,14 +661,12 @@ def __init__( config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, ): - # multimodal_config = config.model_config.multimodal_config super().__init__() # All MiniCPM-V models disable `tie_word_embeddings` but # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot - # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model + # check `tie_word_embeddings` until SGLang integrate MiniCPM-V model # and config class self.config = config - # self.multimodal_config = multimodal_config self.version = get_version_by_config(self.config) self.llm = self.init_llm(config=config, quant_config=quant_config) @@ -598,13 +684,6 @@ def __init__( self.logits_processor = LogitsProcessor(config) - @cached_property - def sampler(self): - if hasattr(self.llm, "sampler"): - return self.llm.sampler - - return get_sampler() - def _get_image_bounds( self, input_ids: torch.Tensor, @@ -666,7 +745,6 @@ def get_embedding( self, input_ids: torch.Tensor, image_inputs: Optional[MiniCPMVImageInputs], - forward_batch: ForwardBatch, ) -> Tuple[torch.Tensor, torch.Tensor]: vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids) @@ -680,10 +758,7 @@ def get_embedding( .to(vlm_embedding.device) ) else: - vision_hidden_states = self.get_vision_hidden_states( - forward_batch, image_inputs - ) - + vision_hidden_states = self.get_vision_hidden_states(image_inputs) # See NOTE in _parse_and_validate_inputs image_bounds = image_inputs["image_bounds"] if len(image_bounds) > 0: @@ -693,6 +768,7 @@ def get_embedding( for start, end in image_bounds.tolist() ] ).to(vlm_embedding.device) + vlm_embedding.scatter_( 0, image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]), @@ -839,7 +915,7 @@ def forward( # There values are useless because their embeddings will be replaced by vision embeddings anyway. 
input_ids.clamp_(min=0, max=self.config.vocab_size - 1) - vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs, forward_batch) + vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) # always pass the input via `inputs_embeds` # to make sure the computation graph is consistent @@ -857,29 +933,6 @@ def forward( input_ids, hidden_states, self.llm.lm_head, forward_batch ) - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - return self.llm.compute_logits(hidden_states, sampling_metadata) - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def get_mm_mapping(self) -> MultiModelKeys: - """ - Get the module prefix in multimodal models - """ - return MultiModelKeys.from_string_field( - language_model="llm", connector="resampler", tower_model="vpm" - ) - def init_llm( self, config: Qwen2Config, @@ -910,9 +963,7 @@ def get_vision_embedding( ) -> torch.Tensor: raise NotImplementedError - def get_vision_hidden_states( - self, forward_batch: ForwardBatch, data: MiniCPMVImageInputs - ) -> torch.Tensor: + def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: raise NotImplementedError @@ -1019,7 +1070,6 @@ def get_vision_embedding( def get_vision_hidden_states( self, - forward_batch: ForwardBatch, data: MiniCPMVImageInputs, ) -> torch.Tensor: pixel_values = data["data"] @@ -1042,15 +1092,18 @@ def get_vision_hidden_states( patch_attn_mask = torch.zeros( (B, 1, max_patches), dtype=torch.bool, device=device ) - for i in range(B): - patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True + + tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device) + mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1] + patch_attn_mask[:, 0, :] = torch.arange( + patch_attn_mask.size(2), device=patch_attn_mask.device + ).unsqueeze(0) < mask_shapes.unsqueeze(1) + vision_embedding = self.vpm( all_pixel_values.type(dtype), - forward_batch=forward_batch, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes, ) - return self.resampler(vision_embedding, tgt_sizes) def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): @@ -1138,7 +1191,7 @@ class MiniCPMV: """ Different versions of MiniCPMV use different visual encoders and LLMs, which is not conducive to the current integration logic of LoRA and - bitsandbytes in vLLM. Therefore, it is necessary to separate them. + bitsandbytes in SGLang. Therefore, it is necessary to separate them. 
""" # Ensure that the LoRA support check passes when the class is not diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 43f6793e4ef..05069edb69b 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -17,6 +17,7 @@ import sglang.srt.distributed.parallel_state as ps from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import get_act_fn +from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -145,61 +146,6 @@ def forward( return hidden_state -class MllamaVisionSdpaAttention(nn.Module): - def __init__(self, config: config_mllama.MllamaVisionConfig): - super().__init__() - - model_parallel_size = get_tensor_model_parallel_world_size() - self.embed_dim = config.hidden_size - self.num_heads = config.attention_heads - self.head_dim = config.hidden_size // config.attention_heads - self.num_local_heads = self.num_heads // model_parallel_size - self.q_size = self.num_local_heads * self.head_dim - self.kv_size = self.num_local_heads * self.head_dim - - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - bias=False, - ) - self.o_proj = RowParallelLinear( - self.num_heads * self.head_dim, - self.embed_dim, - bias=False, - input_is_parallel=True, - ) - - def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_state) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q.view( - q.shape[0], q.shape[1], self.num_local_heads, self.head_dim - ).transpose(1, 2) - k = k.view( - k.shape[0], k.shape[1], self.num_local_heads, self.head_dim - ).transpose(1, 2) - v = v.view( - v.shape[0], v.shape[1], self.num_local_heads, self.head_dim - ).transpose(1, 2) - - # TODO: remove padding in image encoder - attn_output = F.scaled_dot_product_attention( - q, k, v, attn_mask=attention_mask, dropout_p=0.0 - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape( - attn_output.shape[0], attn_output.shape[1], -1 - ) - output, _ = self.o_proj(attn_output) - return output - - class MllamaVisionMLP(nn.Module): def __init__(self, config, quant_config: Optional[QuantizationConfig] = None): super().__init__() @@ -237,7 +183,17 @@ def __init__( self.is_gated = is_gated self.intermediate_size = config.intermediate_size - self.self_attn = MllamaVisionSdpaAttention(config) + self.self_attn = VisionAttention( + self.hidden_size, + self.num_attention_heads, + self.hidden_size, + use_qkv_parallel=True, + quant_config=None, + dropout=0.0, + use_context_forward=False, + use_full_precision_softmax=False, + flatten_batch=False, + ) self.mlp = MllamaVisionMLP(config) self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) @@ -992,6 +948,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: + if "vision_model" in name: + # adapt to VisionAttention + name = name.replace("self_attn.o_proj", "self_attn.proj") + param = params_dict.pop(name) weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 0c01ab9e5b4..46b62f837f6 100644 --- a/python/sglang/srt/models/qwen2.py 
+++ b/python/sglang/srt/models/qwen2.py @@ -249,7 +249,10 @@ def __init__( self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) + if hasattr(self.config, "scale_emb"): + return self.embed_tokens(input_ids) * self.config.scale_emb + else: + return self.embed_tokens(input_ids) def forward( self, diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 0fb85679f7a..365891544e0 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -30,12 +30,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange, repeat +from einops import rearrange from vllm.model_executor.layers.activation import QuickGELU from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig -from sglang.srt.distributed import parallel_state -from sglang.srt.distributed import utils as dist_utils from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear @@ -118,6 +116,7 @@ def __init__( mlp_ratio: float, act_layer: Type[nn.Module] = QuickGELU, norm_layer: Type[nn.Module] = None, + attn_implementation: Optional[str] = "sdpa", quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -126,12 +125,24 @@ def __init__( self.norm1 = norm_layer(dim) self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) + if attn_implementation == "sdpa": + use_context_forward = False + use_full_precision_softmax = False + elif attn_implementation == "flash_attention_2": + use_full_precision_softmax = False + use_context_forward = True + elif attn_implementation == "eager": + use_full_precision_softmax = True + use_context_forward = False self.attn = VisionAttention( embed_dim=dim, num_heads=num_heads, projection_size=dim, use_qkv_parallel=False, + use_context_forward=use_context_forward, + use_full_precision_softmax=use_full_precision_softmax, + flatten_batch=True, quant_config=quant_config, ) self.mlp = Qwen2VisionMLP( @@ -286,7 +297,6 @@ def __init__( norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = embed_dim // num_heads self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList( [ Qwen2VisionBlock( @@ -294,6 +304,7 @@ def __init__( num_heads=num_heads, mlp_ratio=mlp_ratio, norm_layer=norm_layer, + attn_implementation="sdpa", quant_config=quant_config, ) for _ in range(depth) @@ -482,10 +493,6 @@ def forward( opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,). (Use input_metadata.mrope_positions to replace it) - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. 
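# A minimal sketch of the attention-backend selection added to Qwen2VisionBlock
# above, assuming only the three implementations handled in the hunk: "sdpa"
# keeps both flags off, "flash_attention_2" enables the context-forward path,
# and "eager" falls back to a full-precision softmax.
def resolve_vision_attn_flags(attn_implementation: str) -> dict:
    table = {
        "sdpa": dict(use_context_forward=False, use_full_precision_softmax=False),
        "flash_attention_2": dict(use_context_forward=True, use_full_precision_softmax=False),
        "eager": dict(use_context_forward=False, use_full_precision_softmax=True),
    }
    return table[attn_implementation]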
""" if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": positions = forward_batch.mrope_positions @@ -540,15 +547,18 @@ def forward( num_image_tokens = self.calculate_num_image_tokens( image_grid_thws[idx] ) + left_idx = start_idx + (image_offset - prefix_len) right_idx = ( start_idx + (image_offset - prefix_len) + num_image_tokens ) + inputs_embeds[left_idx:right_idx] = image_embeds[ image_embeds_offset : image_embeds_offset + num_image_tokens ] image_embeds_offset += num_image_tokens + input_ids = None hidden_states = self.model( input_ids=input_ids, positions=positions, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index d8d935437b2..ebb346bbc63 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -444,8 +444,6 @@ def load_image(image_file: Union[str, bytes]): else: raise ValueError(f"Invalid image: {image}") - # if image_size is None: - # image_size = image.size return image, image_size diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index f6aa356826d..603bab957bd 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -48,6 +48,7 @@ "test_update_weights_from_disk.py", "test_update_weights_from_tensor.py", "test_vision_chunked_prefill.py", + "test_vision_llm.py", "test_vision_openai_server.py", "test_w8a8_quantization.py", "test_fp8_kvcache.py", @@ -72,7 +73,6 @@ tests.remove(target_suite_name) tests.extend(target_tests) - if __name__ == "__main__": arg_parser = argparse.ArgumentParser() arg_parser.add_argument( diff --git a/test/srt/test_vision_llm.py b/test/srt/test_vision_llm.py new file mode 100644 index 00000000000..7cda64fc0c7 --- /dev/null +++ b/test/srt/test_vision_llm.py @@ -0,0 +1,210 @@ +""" +""" + +import unittest +from io import BytesIO + +import numpy as np +import requests +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import AutoModel, AutoProcessor, AutoTokenizer + +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.conversation import generate_chat_conv +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.openai_api.protocol import ChatCompletionRequest +from sglang.srt.server_args import ServerArgs + +MiniCPMV = "openbmb/MiniCPM-V-2_6" + + +# Test the logits output between HF and SGLang +class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase): + @classmethod + def setUpClass(cls): + cls.image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" + cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + cls.model_path = "" + cls.chat_template = "" + cls.processor = "" + response = requests.get(cls.image_url) + cls.main_image = Image.open(BytesIO(response.content)) + + def compare_outputs(self, sglang_output: torch.Tensor, hf_output: torch.Tensor): + # Convert to float32 for numerical stability if needed + hf = hf_output.float() + sg = sglang_output.float() + + # Basic shape and dtype comparison + print("\n=== Basic Properties ===") + print(f"Shapes match: {hf.shape == sg.shape}") + print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}") + print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}") + + # Move tensors to CPU for numpy operations + hf_np = hf.cpu().numpy() + sg_np = sg.cpu().numpy() + + # Statistical metrics + print("\n=== Statistical Metrics ===") + print(f"Mean absolute difference: {torch.mean(torch.abs(hf - sg)).item():.6f}") + print(f"Max absolute difference: {torch.max(torch.abs(hf - sg)).item():.6f}") + 
print(f"Mean squared error: {torch.mean((hf - sg) ** 2).item():.6f}") + print( + f"Root mean squared error: {torch.sqrt(torch.mean((hf - sg) ** 2)).item():.6f}" + ) + + # Cosine similarity (across feature dimension) + cos_sim = F.cosine_similarity(hf, sg) + print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}") + print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}") + + # Find largest absolute differences + print("\n=== Largest Absolute Differences ===") + diffs = torch.abs(hf - sg) + flat_diffs = diffs.flatten() + + # Get indices of top 10 differences + top_k = 10 + top_values, top_flat_indices = torch.topk(flat_diffs, top_k) + + # Convert flat indices to multidimensional indices + top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), diffs.shape) + + print(f"\nTop {top_k} largest absolute differences:") + print( + "Index".ljust(30) + + "Difference".ljust(15) + + "HF Value".ljust(15) + + "SGLang Value" + ) + print("-" * 75) + + for i in range(top_k): + # Get the index tuple for this difference + idx = tuple(dim[i] for dim in top_indices) + diff_val = top_values[i].item() + hf_val = hf[idx].item() + sg_val = sg[idx].item() + + # Format the index tuple and values + idx_str = str(idx) + print(f"{idx_str:<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}") + + np.testing.assert_allclose(hf_np, sg_np) + + def get_processor_output(self): + json_str = f""" + {{ + "model": "{self.model_path}", + "messages": [ + {{ + "role": "user", + "content": [ + {{ + "type": "image_url", + "image_url": {{ + "url": "{self.image_url}" + }} + }}, + {{ + "type": "text", + "text": "Whats in this picture?" + }} + ] + }} + ] +}} + """ + + req = ChatCompletionRequest.model_validate_json(json_str) + + conv = generate_chat_conv(req, template_name=self.chat_template) + + text = conv.get_prompt() + + # Process inputs using processor + # FIXME: the formal arguments may differ + inputs = self.processor( + text=[text], + images=[self.main_image], + return_tensors="pt", + ).to(self.device) + + return inputs + + def get_sglang_model(self): + model_runner = ModelRunner( + model_config=ModelConfig(self.model_path, model_override_args="{}"), + mem_fraction_static=0.8, + gpu_id=0, + tp_rank=0, + tp_size=1, + nccl_port=12435, + server_args=ServerArgs( + model_path=self.model_path, + disable_cuda_graph=True, + ), + ) + return model_runner.model + + +class TestMiniCPMVLogits(VisionLLMLogitsBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model_path = MiniCPMV + cls.tokenizer = AutoTokenizer.from_pretrained( + cls.model_path, trust_remote_code=True + ) + cls.processor = AutoProcessor.from_pretrained( + cls.model_path, trust_remote_code=True + ) + cls.chat_template = "minicpmv" + + cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + cls.model = AutoModel.from_pretrained( + cls.model_path, torch_dtype=torch.bfloat16, trust_remote_code=True + ).eval() + cls.model.to(cls.device) + + async def test_encode_output(self): + inputs = self.get_processor_output() + + with torch.no_grad(): + model_inputs = { + "input_ids": inputs.input_ids, + "image_bound": inputs.image_bound, + "pixel_values": inputs.pixel_values, + "tgt_sizes": inputs.tgt_sizes, + } + (hf_output, _) = self.model.get_vllm_embedding( + model_inputs, + ) + hf_output = hf_output.squeeze(0) + + with torch.no_grad(): + model = self.get_sglang_model() + input_ids = inputs["input_ids"].to(self.device).flatten() + image_inputs = model._parse_and_validate_inputs( + input_ids=input_ids, + **{ + "pixel_values": 
[inputs["pixel_values"]], + "tgt_sizes": [inputs["tgt_sizes"]], + "im_start_id": [self.tokenizer.im_start_id], + "im_end_id": [self.tokenizer.im_end_id], + "slice_start_id": [self.tokenizer.slice_start_id], + "slice_end_id": [self.tokenizer.slice_end_id], + }, + ) + (sglang_output, _) = model.get_embedding( + input_ids=input_ids, image_inputs=image_inputs + ) + + self.compare_outputs(sglang_output, hf_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index 5be911ab84a..01762202882 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -180,7 +180,9 @@ def test_multi_images_chat_completion(self): assert response.usage.total_tokens > 0 def prepare_video_messages(self, video_path): - max_frames_num = 32 + # the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa + # the size of the video embeds differs from the `modality` argument when preprocessed + max_frames_num = 12 vr = VideoReader(video_path, ctx=cpu(0)) total_frame_num = len(vr) uniform_sampled_frames = np.linspace( From 20453cef6288bdeb5998dd2eee8955f13f88ffea Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Thu, 30 Jan 2025 02:01:23 -0800 Subject: [PATCH 07/12] [test] Lower number of top logprobs to get rid of `-inf` (#3212) --- .../sampling/penaltylib/test_srt_endpoint_with_penalizers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py index 5245905f79b..34565c9ff65 100644 --- a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py +++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py @@ -36,7 +36,7 @@ def tearDownClass(cls): def run_decode( self, return_logprob=True, - top_logprobs_num=5, + top_logprobs_num=3, return_text=True, n=1, **sampling_params, From c38b5fb4f45ad8dd1c4ad1b7b05170c87c0f3ea1 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 30 Jan 2025 19:32:21 +0800 Subject: [PATCH 08/12] update 3rdparty and rms norm for sgl-kernel (#3213) --- sgl-kernel/3rdparty/cutlass | 2 +- sgl-kernel/3rdparty/flashinfer | 2 +- sgl-kernel/pyproject.toml | 2 +- .../csrc/fused_add_rms_norm_kernel.cu | 113 +----------------- sgl-kernel/version.py | 2 +- 5 files changed, 8 insertions(+), 113 deletions(-) diff --git a/sgl-kernel/3rdparty/cutlass b/sgl-kernel/3rdparty/cutlass index b78588d1630..bdd641790ad 160000 --- a/sgl-kernel/3rdparty/cutlass +++ b/sgl-kernel/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit b78588d1630aa6643bf021613717bafb705df4ef +Subproject commit bdd641790ad49353b40ada41330552a78d2f8b5a diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer index 4f1f08989c7..e5a3befbe3e 160000 --- a/sgl-kernel/3rdparty/flashinfer +++ b/sgl-kernel/3rdparty/flashinfer @@ -1 +1 @@ -Subproject commit 4f1f08989c71f92df181e346548c2ca48ae6daf5 +Subproject commit e5a3befbe3e63025f0158bc96b218a9c5f402ac7 diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index aca6f045054..bb7d6943348 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.3" +version = "0.0.3.post1" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.9" diff --git a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu 
b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu index 4c4ecb966ee..f0f3a51744e 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/fused_add_rms_norm_kernel.cu @@ -1,116 +1,11 @@ -// Adapted from https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/include/flashinfer/norm.cuh -// and https://github.com/flashinfer-ai/flashinfer/blob/v0.1.6/python/csrc/norm.cu -// TODO(zhyncs): tmp fix, v0.1.6 enables SGLang e2e to pass CIs unlike v0.2.0 - #include -#include -#include -#include -#include +#include #include "utils.h" using namespace flashinfer; -template -__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual, T* __restrict__ weight, - const uint32_t d, float eps) { - const uint32_t bx = blockIdx.x; - const uint32_t tx = threadIdx.x, ty = threadIdx.y; - constexpr uint32_t warp_size = 32; - const uint32_t num_warps = blockDim.y; - const uint32_t thread_id = tx + ty * warp_size; - const uint32_t num_threads = num_warps * warp_size; - const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads); - extern __shared__ float smem[]; - - float sum_sq = 0.f; - - for (uint32_t i = 0; i < rounds; i++) { - vec_t input_vec; - input_vec.fill(0.f); - vec_t residual_vec; - residual_vec.fill(0.f); - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; j++) { - float x = float(input_vec[j]); - x += float(residual_vec[j]); - sum_sq += x * x; - residual_vec[j] = (T)x; - } - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - residual_vec.store(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } - } - - // first, warp reduce sum -#pragma unroll - for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { - sum_sq += math::shfl_xor_sync(sum_sq, offset); - } - - smem[ty] = sum_sq; - __syncthreads(); - // then, cross warp reduce sum using only the first warp - if (ty == 0) { - sum_sq = (tx < num_warps) ? 
smem[tx] : 0.f; -#pragma unroll - for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { - sum_sq += math::shfl_xor_sync(sum_sq, offset); - } - smem[0] = sum_sq; - } - __syncthreads(); - - float rms_rcp = math::rsqrt(smem[0] / float(d) + eps); - - for (uint32_t i = 0; i < rounds; i++) { - vec_t input_vec; - vec_t weight_vec; - vec_t residual_vec; - input_vec.fill(0.f); - weight_vec.fill(0.f); - residual_vec.fill(0.f); - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - input_vec.load(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - residual_vec.load(residual + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; j++) { - input_vec[j] = float(residual_vec[j]) * rms_rcp * float(weight_vec[j]); - } - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - input_vec.store(input + bx * d + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } - } -} - -template -cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d, float eps = 1e-5, - cudaStream_t stream = 0) { - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t block_size = std::min(1024, d / vec_size); - const uint32_t num_warps = ceil_div(block_size, 32); - dim3 nblks(batch_size); - dim3 nthrs(32, num_warps); - const uint32_t smem_size = num_warps * sizeof(float); - void* args[] = {&input, &residual, &weight, &d, &eps}; - - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = FusedAddRMSNormKernel; - FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); - }); - - return cudaSuccess; -} - void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) { CHECK_INPUT(input); CHECK_INPUT(residual); @@ -130,9 +25,9 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream(); // support float16, bfloat16 and float32 DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { - cudaError_t status = - FusedAddRMSNorm(static_cast(input.data_ptr()), static_cast(residual.data_ptr()), - static_cast(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream); + cudaError_t status = norm::FusedAddRMSNorm( + static_cast(input.data_ptr()), static_cast(residual.data_ptr()), + static_cast(weight.data_ptr()), batch_size, hidden_size, eps, torch_current_stream); TORCH_CHECK(status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status))); return true; diff --git a/sgl-kernel/version.py b/sgl-kernel/version.py index 27fdca497c3..647733203b6 100644 --- a/sgl-kernel/version.py +++ b/sgl-kernel/version.py @@ -1 +1 @@ -__version__ = "0.0.3" +__version__ = "0.0.3.post1" From 468d23cff971b3174c37938f74a007646f9cfb78 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 30 Jan 2025 19:47:50 +0800 Subject: [PATCH 09/12] update setup for sgl-kernel (#3214) --- sgl-kernel/setup.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index f887f5c19f0..90c3cbc1d3c 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -1,5 +1,6 @@ import multiprocessing import os +import sys from pathlib import Path import torch @@ -9,14 +10,8 @@ root = Path(__file__).parent.resolve() -def _update_wheel_platform_tag(): 
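# A PyTorch reference for the semantics of the fused kernel delegated to
# flashinfer's norm::FusedAddRMSNorm in the hunk above, assuming row-major
# (batch, hidden) tensors: the residual is updated in place with the input,
# and the RMS-normalized, weight-scaled result is written back into `input`.
import torch

def fused_add_rms_norm_ref(input: torch.Tensor, residual: torch.Tensor,
                           weight: torch.Tensor, eps: float = 1e-5) -> None:
    residual += input                                    # residual <- input + residual
    x = residual.float()                                 # accumulate statistics in fp32
    rms_rcp = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    input.copy_(((x * rms_rcp) * weight.float()).to(input.dtype))  # input <- normalized, scaled output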
- wheel_dir = Path("dist") - if wheel_dir.exists() and wheel_dir.is_dir(): - old_wheel = next(wheel_dir.glob("*.whl")) - new_wheel = wheel_dir / old_wheel.name.replace( - "linux_x86_64", "manylinux2014_x86_64" - ) - old_wheel.rename(new_wheel) +if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv: + sys.argv.extend(["--plat-name", "manylinux2014_x86_64"]) def _get_cuda_version(): @@ -162,5 +157,3 @@ def _get_version(): }, options={"bdist_wheel": {"py_limited_api": "cp39"}}, ) - -_update_wheel_platform_tag() From 222ce6f1da31b6bfe168513ff85b2d5cad34fb85 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 30 Jan 2025 23:04:41 +0800 Subject: [PATCH 10/12] add tensorrt_llm common and cutlass_extensions as 3rdparty (#3216) Co-authored-by: BBuf <35585791+BBuf@users.noreply.github.com> --- .clang-format-ignore | 1 + .../tensorrt_llm/common/CMakeLists.txt | 22 + .../3rdparty/tensorrt_llm/common/assert.cpp | 34 + .../tensorrt_llm/common/cublasMMWrapper.cpp | 360 +++++++ .../tensorrt_llm/common/cublasMMWrapper.h | 148 +++ .../tensorrt_llm/common/cublasVersionCheck.h | 35 + .../tensorrt_llm/common/cudaBf16Fallbacks.cuh | 313 ++++++ .../tensorrt_llm/common/cudaDriverWrapper.cpp | 187 ++++ .../tensorrt_llm/common/cudaDriverWrapper.h | 138 +++ .../tensorrt_llm/common/cudaFp8Utils.cu | 436 +++++++++ .../tensorrt_llm/common/cudaProfilerUtils.cpp | 84 ++ .../tensorrt_llm/common/cudaTypeUtils.cuh | 752 +++++++++++++++ .../common/customAllReduceUtils.h | 36 + .../3rdparty/tensorrt_llm/common/envUtils.cpp | 214 +++++ .../3rdparty/tensorrt_llm/common/envUtils.h | 60 ++ .../3rdparty/tensorrt_llm/common/logger.cpp | 70 ++ .../3rdparty/tensorrt_llm/common/mathUtils.h | 37 + .../tensorrt_llm/common/memoryUtils.cu | 906 ++++++++++++++++++ .../tensorrt_llm/common/memoryUtils.h | 292 ++++++ .../3rdparty/tensorrt_llm/common/mpiUtils.cpp | 588 ++++++++++++ .../3rdparty/tensorrt_llm/common/nvtxUtils.h | 46 + .../3rdparty/tensorrt_llm/common/opUtils.cpp | 323 +++++++ .../3rdparty/tensorrt_llm/common/opUtils.h | 215 +++++ .../tensorrt_llm/common/quantTypeUtils.cuh | 55 ++ .../tensorrt_llm/common/reduceKernelUtils.cuh | 399 ++++++++ .../3rdparty/tensorrt_llm/common/stlUtils.h | 123 +++ .../tensorrt_llm/common/stringUtils.cpp | 76 ++ .../tensorrt_llm/common/timestampUtils.cpp | 42 + .../tensorrt_llm/common/timestampUtils.h | 25 + .../tensorrt_llm/common/tllmException.cpp | 105 ++ .../3rdparty/tensorrt_llm/common/workspace.h | 87 ++ .../arch/copy_red_global.hpp | 352 +++++++ .../include/cutlass_extensions/arch/mma.h | 120 +++ .../cutlass_extensions/compute_occupancy.h | 88 ++ .../collective/epilogue_moe_finalize.hpp | 550 +++++++++++ .../epilogue/thread/fused_activations.h | 105 ++ .../epilogue_per_row_per_col_scale.h | 352 +++++++ .../threadblock/epilogue_tensor_op_int32.h | 282 ++++++ .../cutlass_extensions/epilogue_helpers.h | 141 +++ .../builders/sm90_gmma_builder_gated.inl | 221 +++++ .../collective/collective_builder_gated.hpp | 58 ++ .../gemm/collective/collective_mma_gated.hpp | 59 ++ ..._mma_gated_tma_gmma_ss_warpspecialized.hpp | 642 +++++++++++++ ..._gated_tma_gmma_ss_warpspecialized_fp8.hpp | 665 +++++++++++++ .../gemm/device/gemm_universal_base_compat.h | 438 +++++++++ .../gemm/device/splitk_gemm_grouped.h | 542 +++++++++++ .../gemm/kernel/default_fpA_intB_traits.h | 162 ++++ .../gemm/kernel/default_int8_traits.h | 57 ++ .../gemm/kernel/default_splitk_gemm_grouped.h | 207 ++++ .../gemm/kernel/fpA_intB_gemm.h | 566 +++++++++++ .../gemm/kernel/fused_moe_kernel.cuh | 218 +++++ 
.../gemm/kernel/fused_moe_kernel_routine.cuh | 799 +++++++++++++++ .../gemm/kernel/fused_moe_kernel_traits.cuh | 215 +++++ .../gemm/kernel/gemm_moe_problem_visitor.h | 73 ++ .../gemm/kernel/gemm_universal_gated.hpp | 70 ++ .../gemm/kernel/gemm_with_epilogue_visitor.h | 585 +++++++++++ .../gemm/kernel/mixed_gemm_B_layout.h | 143 +++ .../gemm/kernel/moe_cute_util.cuh | 185 ++++ .../gemm/kernel/moe_cutlass_kernel.h | 553 +++++++++++ .../gemm/kernel/moe_problem_visitor.h | 344 +++++++ ..._gated_tma_warpspecialized_cooperative.hpp | 646 +++++++++++++ ...emm_gated_tma_warpspecialized_pingpong.hpp | 621 ++++++++++++ .../gemm/kernel/splitk_gemm_grouped.h | 494 ++++++++++ .../gemm/threadblock/default_dq_mma.h | 125 +++ .../threadblock/default_dq_mma_multistage.h | 302 ++++++ .../threadblock/default_dq_mma_pipelined.h | 284 ++++++ .../gemm/threadblock/default_mma.h | 351 +++++++ .../gemm/threadblock/default_mma_bf16.h | 353 +++++++ .../gemm/threadblock/dq_mma_base.h | 257 +++++ .../gemm/threadblock/dq_mma_multistage.h | 110 +++ .../dq_mma_multistage_finegrained.h | 708 ++++++++++++++ .../threadblock/dq_mma_multistage_percol.h | 647 +++++++++++++ .../gemm/threadblock/dq_mma_pipelined.h | 106 ++ .../dq_mma_pipelined_finegrained.h | 486 ++++++++++ .../threadblock/dq_mma_pipelined_percol.h | 399 ++++++++ .../gemm/warp/default_mma_tensor_op.h | 107 +++ .../warp/mma_tensorop_compute_B_with_f16.h | 306 ++++++ .../gemm/warp/mma_tensorop_dequantizer.h | 463 +++++++++ .../include/cutlass_extensions/gemm_configs.h | 224 +++++ .../interleaved_numeric_conversion.h | 447 +++++++++ .../tile_interleaved_layout.h | 66 ++ .../fine_grained_scale_zero_iterator.h | 250 +++++ .../cutlass_extensions/util/gather_tensor.hpp | 181 ++++ .../cutlass_extensions/weight_only_quant_op.h | 58 ++ sgl-kernel/THIRDPARTYNOTICES.txt | 205 ++++ sgl-kernel/setup.py | 4 + 86 files changed, 23201 insertions(+) create mode 100644 .clang-format-ignore create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt create mode 100755 sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp create mode 100644 
sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh create mode 100644 
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h create mode 100644 
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h diff --git a/.clang-format-ignore b/.clang-format-ignore new file mode 100644 index 00000000000..15c76cc457f --- /dev/null +++ b/.clang-format-ignore @@ -0,0 +1 @@ +sgl-kernel/3rdparty/tensorrt_llm/* diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt b/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt new file mode 100644 index 00000000000..e479b298db4 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt @@ -0,0 +1,22 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & +# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# +file(GLOB SRCS *.cpp) +file(GLOB CU_SRCS *.cu) + +add_library(common_src OBJECT ${SRCS} ${CU_SRCS}) +set_property(TARGET common_src PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET common_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp new file mode 100755 index 00000000000..eaaf6624472 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/common/assert.h" + +namespace +{ + +bool initCheckDebug() +{ + auto constexpr kDebugEnabled = "TLLM_DEBUG_MODE"; + auto const debugEnabled = std::getenv(kDebugEnabled); + return debugEnabled && debugEnabled[0] == '1'; +} +} // namespace + +bool DebugConfig::isCheckDebugEnabled() +{ + static bool const debugEnabled = initCheckDebug(); + return debugEnabled; +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp new file mode 100644 index 00000000000..351257f4d2e --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp @@ -0,0 +1,360 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/cublasMMWrapper.h" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cublasVersionCheck.h" +#include + +#ifndef CUDART_VERSION +#error CUDART_VERSION Undefined! +#endif + +namespace tensorrt_llm +{ +namespace common +{ + +CublasMMWrapper::CublasMMWrapper(std::shared_ptr cublasHandle, + std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace) + : mCublasHandle(cublasHandle) + , mCublasLtHandle(cublasltHandle) + , mStream(stream) + , mCublasWorkspace(workspace) +{ +} + +CublasMMWrapper::~CublasMMWrapper() {} + +CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper) + : mCublasHandle(wrapper.mCublasHandle) + , mCublasLtHandle(wrapper.mCublasLtHandle) + , mStream(wrapper.mStream) +{ +} + +void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, + int const k, int const lda, int const ldb, int const ldc, int8_t fastAcc) +{ + // -------------------------------------- + // Create descriptors for the original matrices + check_cuda_error( + cublasLtMatrixLayoutCreate(&mADesc, mAType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda)); + check_cuda_error( + cublasLtMatrixLayoutCreate(&mBDesc, mBType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? 
n : k, ldb)); + check_cuda_error(cublasLtMatrixLayoutCreate(&mCDesc, mCType, m, n, ldc)); + check_cuda_error(cublasLtMatmulDescCreate(&mOperationDesc, mComputeType, mScaleType)); + check_cuda_error(cublasLtMatmulDescSetAttribute( + mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t))); + check_cuda_error(cublasLtMatmulDescSetAttribute( + mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t))); + check_cuda_error( + cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAcc, sizeof(int8_t))); +} + +void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b) +{ + check_cuda_error( + cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(void*))); + check_cuda_error( + cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(void*))); +} + +void CublasMMWrapper::destroyDescriptors() +{ + check_cuda_error(cublasLtMatmulDescDestroy(mOperationDesc)); + check_cuda_error(cublasLtMatrixLayoutDestroy(mADesc)); + check_cuda_error(cublasLtMatrixLayoutDestroy(mBDesc)); + check_cuda_error(cublasLtMatrixLayoutDestroy(mCDesc)); + mOperationDesc = NULL; + mADesc = NULL; + mBDesc = NULL; + mCDesc = NULL; +} + +void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc) +{ + Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f); +} + +void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, + std::optional const& heuristic) +{ + if (heuristic) + { + Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, /* hasAlgo */ (*heuristic).algo, + (*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE, + /* usingCublasLt */ true); + } + else + { + Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, {}, /* hasAlgo */ false, + /* usingCublasLt */ true); + } +} + +void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, + std::optional const& heuristic) +{ + if (heuristic) + { + Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, /* hasAlgo */ (*heuristic).algo, + (*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE, + /* usingCublasLt */ true); + } + else + { + Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false, + /* usingCublasLt */ true); + } +} + +void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta) +{ + bool usingCublasLt = mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3; + + Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false, + /* usingCublasLt */ usingCublasLt); +} + +void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, void const* B, int const ldb, void* C, int const 
ldc, float f_alpha, float f_beta, + cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt) +{ + half h_alpha = (half) (f_alpha); + half h_beta = (half) (f_beta); + + // TODO: default cublas libs + usingCublasLt = usingCublasLt && (mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3); + bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F; + int batch_count = 1; + // fp32 use cublas as default + // fp16 use cublasLt as default + void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); + void const* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); + int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE; + + if (usingCublasLt) + { + if (hasAlgo) + { + hasAlgo = checkTactic(transa, transb, m, n, k, lda, ldb, ldc, algo); + } + + check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C, + mCDesc, C, mCDesc, (hasAlgo ? (&algo) : NULL), mCublasWorkspace, workspaceSize, mStream)); + + sync_check_cuda_error(); + } + else + { + check_cuda_error(cublasSetStream(getCublasHandle(), mStream)); + check_cuda_error(cublasSetWorkspace(getCublasHandle(), mCublasWorkspace, workspaceSize)); + // Go with default heuristic to choose tactic as cuBLAS does not allow to choose tactics in Ampere+ + cublasGemmAlgo_t cublasAlgo = CUBLAS_GEMM_DEFAULT; + check_cuda_error(cublasGemmEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, B, mBType, ldb, + beta, C, mCType, ldc, mComputeType, static_cast(cublasAlgo))); + sync_check_cuda_error(); + } +} + +void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, + int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, + const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha, + float const f_beta) +{ + half h_alpha = (half) f_alpha; + half h_beta = (half) f_beta; + + int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0; + void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); + void const* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); + + check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, + strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType, + mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); +} + +void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, + int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, + void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, + cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType) +{ + half h_alpha = (half) f_alpha; + half h_beta = (half) f_beta; + + bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0; + void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); + void const* beta = isFp16ComputeType ? 
reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); + + check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda, + strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType, + mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); +} + +void CublasMMWrapper::setWorkspace(void* workspace) +{ + mCublasWorkspace = workspace; +} + +void CublasMMWrapper::setFP32GemmConfig() +{ + setGemmConfig(CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F); +} + +void CublasMMWrapper::setFP16GemmConfig(cudaDataType_t outputType) +{ + setGemmConfig(CUDA_R_16F, CUDA_R_16F, outputType, CUDA_R_32F); +} + +#ifdef ENABLE_BF16 +void CublasMMWrapper::setBF16GemmConfig(cudaDataType_t outputType) +{ + setGemmConfig(CUDA_R_16BF, CUDA_R_16BF, outputType, CUDA_R_32F); +} +#endif + +#ifdef ENABLE_FP8 +void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType) +{ + setGemmConfig(CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, outputType, CUDA_R_32F); +} +#endif + +void CublasMMWrapper::setGemmConfig( + cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType) +{ + mAType = aType; + mBType = bType; + mCType = cType; + bool isFp16ComputeType = computeType == CUDA_R_16F; + if (isFp16ComputeType) + { + mComputeType = CUBLAS_COMPUTE_16F; + mScaleType = CUDA_R_16F; + } + else + { + mComputeType = CUBLAS_COMPUTE_32F; + mScaleType = CUDA_R_32F; + } +} + +CublasDataType CublasMMWrapper::getCublasDataType(cudaDataType_t data_type) +{ + if (data_type == CUDA_R_16F) + { + return HALF_DATATYPE; + } + else if (data_type == CUDA_R_32F) + { + return FLOAT_DATATYPE; + } + else if (data_type == CUDA_R_8I) + { + return INT8_DATATYPE; + } +#ifdef ENABLE_BF16 + else if (data_type == CUDA_R_16BF) + { + return BFLOAT16_DATATYPE; + } +#endif + return FLOAT_DATATYPE; +} + +void CublasMMWrapper::setStream(cudaStream_t stream) +{ + mStream = stream; +} + +bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, + int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo) +{ + TLLM_CHECK_WITH_INFO( + descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function"); + + int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE; + + cublasLtMatmulHeuristicResult_t heurResult; + cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck( + getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc, &algo, &heurResult); + + if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS + || heurResult.workspaceSize > CUBLAS_WORKSPACE_SIZE) + { + return false; + } + + sync_check_cuda_error(); + + return true; +} + +std::vector CublasMMWrapper::getTactics(cublasOperation_t transa, + cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc) +{ + TLLM_CHECK_WITH_INFO( + descriptorsCreated(), "Descriptors are not created! 
Call createDescriptors before calling this function"); + + auto const heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc); + + sync_check_cuda_error(); + + return heuristics; +} + +std::vector CublasMMWrapper::getTactics(cublasLtHandle_t lightHandle, + cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, + cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc) +{ +#if TLLM_CUBLAS_VER_LE(11, 4, 2) + TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2."); + return {}; +#else + std::vector heuristics(200); + cublasLtMatmulPreference_t preference; + check_cuda_error(cublasLtMatmulPreferenceCreate(&preference)); + check_cuda_error(cublasLtMatmulPreferenceInit(preference)); + uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE; + check_cuda_error(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); + // Restrict reduction algorithms for numerical stability and better determinism + uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_MASK; + check_cuda_error(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, &reduction_mask, sizeof(reduction_mask))); +#if TLLM_CUBLAS_VER_LT(12, 0, 0) + uint32_t pointer_mode_mask = 0; + check_cuda_error(cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask))); +#endif + + int return_count = 0; + check_cuda_error(cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, + heuristics.size(), heuristics.data(), &return_count)); + heuristics.resize(return_count); + + return heuristics; +#endif +} + +} // namespace common + +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h new file mode 100644 index 00000000000..79b7c92a47d --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h @@ -0,0 +1,148 @@ +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "tensorrt_llm/common/cudaUtils.h" +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ + +class CublasMMWrapper +{ +protected: + std::shared_ptr mCublasHandle; + std::shared_ptr mCublasLtHandle; + + cudaDataType_t mAType{}; + cudaDataType_t mBType{}; + cudaDataType_t mCType{}; + cublasComputeType_t mComputeType{}; + cudaDataType_t mScaleType{}; + + cublasLtMatmulDesc_t mOperationDesc{NULL}; + cublasLtMatrixLayout_t mADesc{NULL}; + cublasLtMatrixLayout_t mBDesc{NULL}; + cublasLtMatrixLayout_t mCDesc{NULL}; + + cudaStream_t mStream; + + void* mCublasWorkspace = nullptr; + +private: + bool descriptorsCreated() const + { + return mOperationDesc != NULL && mADesc != NULL && mBDesc != NULL && mCDesc != NULL; + } + +public: + CublasMMWrapper(std::shared_ptr cublasHandle, std::shared_ptr cublasLtHandle, + cudaStream_t stream, void* workspace); + + ~CublasMMWrapper(); + + CublasMMWrapper(CublasMMWrapper const& wrapper); + + /********************** GEMMs **********************/ + void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, + int const lda, void const* B, int const ldb, void* C, int const ldc); + + void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, + int const lda, void const* B, int const ldb, void* C, int const ldc, + std::optional const& algo); + + void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, + int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, + std::optional const& algo); + + void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, + int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta); + + void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, + int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, + cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt); + + void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB, + void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f, + float const f_beta = 0.0f); + + void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B, + cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType, + int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType); + + /********************** Tactic selection helpers **********************/ + bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo); + + std::vector getTactics(cublasOperation_t transa, cublasOperation_t transb, + int const m, int const n, int const k, int const lda, int const ldb, int const ldc); + + std::vector getTactics(cublasLtHandle_t lightHandle, + cublasLtMatmulDesc_t 
computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, + cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc); + + using MatrixLayout = std::tuple; + using cache_idx_t = std::tuple>; + + MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc); + + /********************** Utils **********************/ + void setWorkspace(void* workspace); + + void setFP32GemmConfig(); + void setFP16GemmConfig(cudaDataType_t outputType = CUDA_R_16F); +#ifdef ENABLE_BF16 + void setBF16GemmConfig(cudaDataType_t outputType = CUDA_R_16BF); +#endif +#ifdef ENABLE_FP8 + void setFP8GemmConfig(cudaDataType_t outputType = CUDA_R_16F); +#endif + + void setStream(cudaStream_t stream); + + void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType); + + CublasDataType getCublasDataType(cudaDataType_t data_type); + + void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + int const lda, int const ldb, int const ldc, int8_t fastAcc = 0); + void setScaleDescriptors(void* scale_a, void* scale_b); + void destroyDescriptors(); + + cublasHandle_t getCublasHandle() + { + return *(this->mCublasHandle); + } + + cublasLtHandle_t getCublasLtHandle() const + { + return *(this->mCublasLtHandle); + } +}; + +} // namespace common + +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h new file mode 100644 index 00000000000..1ee72c63566 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// We don't want to include cublas_api.h. It contains the CUBLAS_VER_* macro +// definition which is not sufficient to determine if we include cublas.h, +// cublas_v2.h or cublasLt.h. 
+ +#define TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) (MAJOR * 10000 + MINOR * 100 + PATCH) +#define TLLM_CUBLAS_VER_LE(MAJOR, MINOR, PATCH) \ + TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ + <= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) +#define TLLM_CUBLAS_VER_LT(MAJOR, MINOR, PATCH) \ + TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ + < TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) +#define TLLM_CUBLAS_VER_GE(MAJOR, MINOR, PATCH) \ + TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ + >= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) +#define TLLM_CUBLAS_VER_GT(MAJOR, MINOR, PATCH) \ + TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ + > TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh new file mode 100644 index 00000000000..0519251e6fd --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh @@ -0,0 +1,313 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/cudaBf16Wrapper.h" +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ + +#ifdef ENABLE_BF16 +inline __device__ float2 bf1622float2(const __nv_bfloat162 val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = __low2float(val); + f_val.y = __high2float(val); + return f_val; +#else + return __bfloat1622float2(val); +#endif +} + +inline __device__ int16_t bf1622int16(__nv_bfloat162 val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = max(min(__low2float(val), 127.f), -128.f); + f_val.y = max(min(__high2float(val), 127.f), -128.f); + + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = static_cast(static_cast(f_val.x)); + int8[1] = static_cast(static_cast(f_val.y)); + return int16; +#else + val = __hmin2(val, make_bfloat162(127., 127.)); + val = __hmax2(val, make_bfloat162(-128., -128.)); + + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = static_cast(static_cast(val.x)); + int8[1] = static_cast(static_cast(val.y)); + return int16; +#endif +} + +inline __device__ __nv_bfloat162 float22bf162(const float2 val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __floats2bfloat162_rn(val.x, val.y); +#else + return __float22bfloat162_rn(val); +#endif +} + +inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + __nv_bfloat162 val2; + val2.x = val; + val2.y = val; + return val2; +#else + return __bfloat162bfloat162(val); +#endif +} + +inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); 
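+    // Pre-SM80 fallback: the remaining bf16 halves are likewise widened to
+    // float, added pairwise, and packed back with __floats2bfloat162_rn below.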
+ fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); +#else + return __hadd2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y)); +#else + return __hadd(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); +#else + return __hsub2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y)); +#else + return __hsub(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); +#else + return __hmul2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y)); +#else + return __hmul(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh, fzl, fzh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + fzl = __low2float(z); + fzh = __high2float(z); + return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); +#else + return __hfma2(x, y, z); +#endif +} + +inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); +#else + return __hfma(x, y, z); +#endif +} + +inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh; + fxl = __low2float(x); + fxh = __high2float(x); + ; + return __floats2bfloat162_rn(expf(fxl), expf(fxh)); +#else + return h2exp(x); +#endif +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) +#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020) + +inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ + __nv_bfloat162 t; + t.x = x; + t.y = y; + return t; +} +#endif +#endif + +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); +#else + return a + b + c; +#endif +} + +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + 
__bfloat162float(d)); +#else + return (__nv_bfloat16) ((float) a + (float) b + (float) c + (float) d); +#endif +} + +inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); +#else + return a + b + c; +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); +#else + return a * b * c; +#endif +} + +inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); +#else + return a * b * c; +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + fdl = __low2float(d); + fdh = __high2float(d); + return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); +#else + return a * b * c + d; +#endif +} + +#endif // ENABLE_BF16 + +} // namespace common +} // namespace tensorrt_llm + +// Operator definitions intentionally in global namespace +namespace +{ +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) +#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020) + +inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ + return tensorrt_llm::common::bf16hmul2(x, y); +}; + +inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ + return tensorrt_llm::common::bf16hadd2(x, y); +}; +#endif +#endif +} // namespace diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp new file mode 100644 index 00000000000..7eca46a1cab --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#define CUDA_LIB_NAME "cuda" + +#if defined(_WIN32) +#include +#define dllOpen(name) LoadLibrary("nv" name ".dll") +#define dllClose(handle) FreeLibrary(static_cast(handle)) +#define dllGetSym(handle, name) static_cast(GetProcAddress(static_cast(handle), name)) +#else // For non-Windows platforms +#include +#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY) +#define dllClose(handle) dlclose(handle) +#define dllGetSym(handle, name) dlsym(handle, name) +#endif // defined(_WIN32) + +#include "cudaDriverWrapper.h" +#include "tensorrt_llm/common/assert.h" +#include +#include + +namespace tensorrt_llm::common +{ + +std::shared_ptr CUDADriverWrapper::getInstance() +{ + static std::mutex mutex; + static std::weak_ptr instance; + std::shared_ptr result = instance.lock(); + if (result) + { + return result; + } + + std::lock_guard lock(mutex); + result = instance.lock(); + if (!result) + { + result = std::shared_ptr(new CUDADriverWrapper()); + instance = result; + } + return result; +} + +CUDADriverWrapper::CUDADriverWrapper() + : handle(dllOpen(CUDA_LIB_NAME)) +{ + + TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly."); + + auto load_sym = [](void* handle, char const* name) + { + void* ret = dllGetSym(handle, name); + return ret; + }; + + *reinterpret_cast(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName"); + *reinterpret_cast(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage"); + *reinterpret_cast(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute"); + *reinterpret_cast(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete"); + *reinterpret_cast(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload"); + *reinterpret_cast(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy"); + *reinterpret_cast(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData"); + *reinterpret_cast(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2"); + *reinterpret_cast(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction"); + *reinterpret_cast(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2"); + *reinterpret_cast(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2"); + *reinterpret_cast(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2"); + *reinterpret_cast(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel"); + *reinterpret_cast(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel"); + *reinterpret_cast(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled"); + *reinterpret_cast(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2"); +} + +CUDADriverWrapper::~CUDADriverWrapper() +{ + dllClose(handle); +} + +CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const +{ + return (*_cuGetErrorName)(error, pStr); +} + +CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const +{ + return (*_cuGetErrorMessage)(error, pStr); +} + +CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const +{ + return (*_cuFuncSetAttribute)(hfunc, attrib, value); +} + +CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const +{ + return (*_cuLinkComplete)(state, cubinOut, sizeOut); +} + +CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const +{ + return (*_cuModuleUnload)(hmod); +} + +CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const +{ + return (*_cuLinkDestroy)(state); +} + +CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* 
module, void const* image) const +{ + return (*_cuModuleLoadData)(module, image); +} + +CUresult CUDADriverWrapper::cuLinkCreate( + unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const +{ + return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut); +} + +CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const +{ + return (*_cuModuleGetFunction)(hfunc, hmod, name); +} + +CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const +{ + return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name); +} + +CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, + unsigned int numOptions, CUjit_option* options, void** optionValues) const +{ + return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues); +} + +CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, + char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const +{ + return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues); +} + +CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const +{ + return (*_cuLaunchCooperativeKernel)( + f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams); +} + +CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const +{ + return (*_cuLaunchKernel)( + f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra); +} + +CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, + cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const +{ + return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides, + boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill); +} + +CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const +{ + return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount); +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h new file mode 100644 index 00000000000..c4d470a85f0 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CUDA_DRIVER_WRAPPER_H +#define CUDA_DRIVER_WRAPPER_H + +#include "tensorrt_llm/common/assert.h" +#include +#include +#include +#include + +namespace tensorrt_llm::common +{ + +class CUDADriverWrapper +{ +public: + static std::shared_ptr getInstance(); + + ~CUDADriverWrapper(); + CUDADriverWrapper(CUDADriverWrapper const&) = delete; + CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete; + CUDADriverWrapper(CUDADriverWrapper&&) = delete; + CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete; + + CUresult cuGetErrorName(CUresult error, char const** pStr) const; + + CUresult cuGetErrorMessage(CUresult error, char const** pStr) const; + + CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const; + + CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const; + + CUresult cuModuleUnload(CUmodule hmod) const; + + CUresult cuLinkDestroy(CUlinkState state) const; + + CUresult cuModuleLoadData(CUmodule* module, void const* image) const; + + CUresult cuLinkCreate( + unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const; + + CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const; + + CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const; + + CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions, + CUjit_option* options, void** optionValues) const; + + CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name, + unsigned int numOptions, CUjit_option* options, void** optionValues) const; + + CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const; + + CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, + CUstream hStream, void** kernelParams, void** extra) const; + + CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, + void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim, + cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const; + + CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const; + +private: + void* handle; + CUDADriverWrapper(); + + CUresult (*_cuGetErrorName)(CUresult, char const**); + CUresult (*_cuGetErrorMessage)(CUresult, char const**); + CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int); + CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*); + CUresult (*_cuModuleUnload)(CUmodule); + CUresult 
(*_cuLinkDestroy)(CUlinkState); + CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*); + CUresult (*_cuModuleLoadData)(CUmodule*, void const*); + CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*); + CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*); + CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**); + CUresult (*_cuLinkAddData)( + CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**); + CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, CUstream, void**); + CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, + CUstream hStream, void** kernelParams, void** extra); + CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, + cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); + CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount); +}; + +template +void checkDriver( + T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line) +{ + if (result) + { + char const* errorName = nullptr; + char const* errorMsg = nullptr; + wrap.cuGetErrorName(result, &errorName); + wrap.cuGetErrorMessage(result, &errorMsg); + throw TllmException( + file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg)); + } +} + +} // namespace tensorrt_llm::common + +/* + * Macros compliant with TensorRT coding conventions + */ +#define TLLM_CU_CHECK(stat) \ + do \ + { \ + tensorrt_llm::common::checkDriver( \ + (stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \ + } while (0) + +#endif // CUDA_DRIVER_WRAPPER_H diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu new file mode 100644 index 00000000000..8e140609f2a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu @@ -0,0 +1,436 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/reduceKernelUtils.cuh" +#include +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ +#ifdef ENABLE_FP8 + +constexpr int CTA_SIZE = 256; + +template +__inline__ __device__ float scale(float a, float b) +{ + return QUANTIZE ? a / b : a * b; +} + +template +__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda) +{ + for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x) + { + + if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL) + { + output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[i % lda]))); + } + else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN) + { + output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[i / lda]))); + } + else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR) + { + output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[0]))); + } + } +} + +template +void invokeQuantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream) +{ + dim3 grid(1024); + dim3 block(CTA_SIZE); + if (quantize_mode == QuantizeMode::PER_CHANNEL) + { + scaleMatrix + <<>>(output, input_scale, input, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TOKEN) + { + scaleMatrix<<>>(output, input_scale, input, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TENSOR) + { + scaleMatrix<<>>(output, input_scale, input, numel, lda); + } + sync_check_cuda_error(); +} + +template +void invokeDequantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream) +{ + dim3 grid(1024); + dim3 block(CTA_SIZE); + if (quantize_mode == QuantizeMode::PER_CHANNEL) + { + scaleMatrix + <<>>(output, input_scale, input, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TOKEN) + { + scaleMatrix<<>>(output, input_scale, input, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TENSOR) + { + scaleMatrix + <<>>(output, input_scale, input, numel, lda); + } + sync_check_cuda_error(); +} + +template +__global__ void fakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel) +{ + for (int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < numel; tid += blockDim.x * gridDim.x) + { + T_FAKE tmp = (T_FAKE) (static_cast(src[tid])); + dst[tid] = (T_OUT) (static_cast(tmp)); + } +} + +template +void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream) +{ + fakeQuantize<<<1024, CTA_SIZE, 0, stream>>>(dst, src, numel); + sync_check_cuda_error(); +} + +template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>( + float* dst, float const* src, const int64_t numel, cudaStream_t stream); +template void invokeFakeQuantize( + float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream); +template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>( + half* dst, half const* src, const int64_t numel, cudaStream_t stream); +template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>( + __nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream); + +template void invokeFakeQuantize( + half* dst, float const* src, const int64_t numel, cudaStream_t stream); + +__device__ float atomicMaxExtd(float* address, float 
val) +{ + assert(val >= 0); + unsigned int* address_as_u = reinterpret_cast(address); + unsigned int old = atomicMax(address_as_u, __float_as_uint(val)); + return __uint_as_float(old); +} + +template +inline __device__ T atomicMaxExtdV2(T* address, T val) +{ +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + static_assert(std::is_same_v | std::is_same_v, "T needs to be either half or bfloat16"); + // The address in 64 bits. + uint64_t address_u64 = reinterpret_cast(address); + + // Pack the input value into 32 bits. + union + { + T v[2]; + uint16_t u[2]; + } old, tmp = {}; + + int const loc = (address_u64 & 0x2) >> 1; + tmp.v[loc] = val; + + // 4B aligned pointer. + auto aligned_address = reinterpret_cast(address_u64 & ~0x3ull); + + if constexpr (std::is_same_v) + { + asm volatile("atom.global.v2.f16.max.noftz {%0, %1}, [%2], {%3, %4};" + : "=h"(old.u[0]), "=h"(old.u[1]) + : "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1])); + } + if constexpr (std::is_same_v) + { + asm volatile("atom.global.v2.bf16.max.noftz {%0, %1}, [%2], {%3, %4};" + : "=h"(old.u[0]), "=h"(old.u[1]) + : "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1])); + } + + // Return the correct half. + return old.v[loc]; +#endif +} + +__device__ half atomicMaxExtd(half* address, half val) +{ + unsigned short int* address_as_u = reinterpret_cast(address); + unsigned short int old = *address_as_u, assumed; + + while (val > __ushort_as_half(old)) + { + assumed = old; + old = atomicCAS(address_as_u, assumed, __half_as_ushort(val)); + } + + return __ushort_as_half(old); +} + +__device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val) +{ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + unsigned short int* address_as_u = reinterpret_cast(address); + unsigned short int old = *address_as_u, assumed; + + while (val > __ushort_as_bfloat16(old)) + { + assumed = old; + old = atomicCAS(address_as_u, assumed, __bfloat16_as_ushort(val)); + } + + return __ushort_as_bfloat16(old); +#else + assert(0); + asm volatile("brkpt;\n" ::); + return __nv_bfloat16(0); +#endif +} + +template +__global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t size, const int64_t n) +{ + constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); + if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL) + { + for (int64_t col = threadIdx.x; col < n; col += blockDim.x) + { + float max = 0.f; + for (int64_t i = col + n * blockIdx.x; i < size; i += gridDim.x * n) + { + auto val = fabs(static_cast(weights[i])); + max = max > val ? max : val; + } + auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + if constexpr (std::is_same_v) + { + atomicMaxExtd(quant_ptr + col, scale); + } + else + { + auto const address_u64 = reinterpret_cast(quant_ptr + col); + if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0)) + atomicMaxExtd(quant_ptr + col, scale); + else + atomicMaxExtdV2(quant_ptr + col, scale); + } +#else // Vector atomics require __CUDA_ARCH__ >= 900 + atomicMaxExtd(quant_ptr + col, scale); +#endif + } + } + else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN) + { + auto const nrows = size / n; + for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x) + { + float max = 0.f; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) + { + auto val = fabs(static_cast(weights[row * n + i])); + max = max > val ? 
max : val; + } + max = blockReduceMax(max); + if (threadIdx.x == 0) + { + auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); + quant_ptr[row] = scale; + } + } + } + else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR) + { + float max = 0.f; + for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += gridDim.x * blockDim.x) + { + auto val = fabs(static_cast(weights[i])); + max = max > val ? max : val; + } + max = blockReduceMax(max); + if (threadIdx.x == 0) + { + auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); + atomicMaxExtd(quant_ptr, scale); + } + } +} + +template +void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t numel, const int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream) +{ + if (quantize_mode == QuantizeMode::PER_TOKEN) + { + dim3 block(CTA_SIZE); + dim3 grid(numel / lda); + computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_CHANNEL) + { + dim3 block(CTA_SIZE); + dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE); + cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream); + sync_check_cuda_error(); + computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TENSOR) + { + dim3 block(1024); + dim3 grid(1024); + cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream); + sync_check_cuda_error(); + computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); + } + sync_check_cuda_error(); +} + +#define DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(type_scale, type_in) \ + template void invokeComputeFP8QuantizeScale(type_scale * input_scale, type_in const* weights, \ + int64_t numel, int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream); + +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(half, half); +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, half); +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, float); +#ifdef ENABLE_BF16 +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(__nv_bfloat16, __nv_bfloat16); +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, __nv_bfloat16); +#endif + +template +__global__ void dynamicQuantizeMatrixPerToken( + T_OUT* output, T_S* quant_ptr, T_IN const* input, int64_t numel, int64_t lda) +{ + extern __shared__ __align__(sizeof(float)) char _shmem[]; + T_IN* shmem = reinterpret_cast(_shmem); + constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); + auto const nrows = numel / lda; + for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x) + { + float max = 0.f; + for (int64_t i = threadIdx.x; i < lda; i += blockDim.x) + { + auto const in = input[row * lda + i]; + shmem[i] = in; + auto val = fabs(static_cast(in)); + max = max > val ? 
max : val; + } + max = blockAllReduceMax(max); // __syncthreads() called so we can read shmem + auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); + for (int64_t i = threadIdx.x; i < lda; i += blockDim.x) + { + // true means we are quantizing + output[row * lda + i] = (T_OUT) scale(static_cast(shmem[i]), static_cast(s)); + } + if (threadIdx.x == 0) + { + quant_ptr[row] = s; + } + } +} + +template +void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* input, const int64_t numel, + const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream) +{ + if (quantize_mode == QuantizeMode::PER_TOKEN) + { + dim3 grid(numel / lda); + bool use_shmem = true; + auto const shmem_size = lda * sizeof(T_IN); + if (shmem_size >= (48 << 10)) + { + cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + use_shmem = ret == cudaSuccess; + } + if (use_shmem) + { + // ensure the threadblock is as large as possible to increase occupancy + dim3 block(std::min((lda + 31) / 32 * 32, static_cast(1024))); + dynamicQuantizeMatrixPerToken<<>>(output, quant_ptr, input, numel, lda); + } + else + { + dim3 block(CTA_SIZE); + computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); + sync_check_cuda_error(); + invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); + } + } + else if (quantize_mode == QuantizeMode::PER_CHANNEL) + { + dim3 block(CTA_SIZE); + dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE); + cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream); + sync_check_cuda_error(); + computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); + sync_check_cuda_error(); + invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); + } + else if (quantize_mode == QuantizeMode::PER_TENSOR) + { + dim3 block(1024); + dim3 grid(1024); + cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream); + sync_check_cuda_error(); + computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); + sync_check_cuda_error(); + invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); + } + sync_check_cuda_error(); +} + +#define DEFINE_INVOKE_QUANTIZE_MATRIX(type_out, type_scale, type_in) \ + template void invokeQuantizeMatrix(type_out * output, \ + type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ + cudaStream_t stream); \ + template void invokeDequantizeMatrix(type_out * output, \ + type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ + cudaStream_t stream); \ + template void invokeComputeScalesAndQuantizeMatrix(type_out * output, \ + type_scale * input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ + cudaStream_t stream); + +#ifdef ENABLE_FP8 +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, float); +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, half); +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, half, half); +DEFINE_INVOKE_QUANTIZE_MATRIX(half, half, __nv_fp8_e4m3); +DEFINE_INVOKE_QUANTIZE_MATRIX(float, float, __nv_fp8_e4m3); +DEFINE_INVOKE_QUANTIZE_MATRIX(half, float, __nv_fp8_e4m3); +#ifdef ENABLE_BF16 +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16); +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_bfloat16, __nv_bfloat16, __nv_fp8_e4m3); +#endif +#endif + +#endif // ENABLE_FP8 +} // namespace common +} // namespace tensorrt_llm diff --git 
a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp new file mode 100644 index 00000000000..5576fe782fa --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/cudaProfilerUtils.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/stringUtils.h" +#include +#include + +namespace +{ + +std::tuple, std::unordered_set> populateIterationIndexesImpl( + std::string const& envVarName) +{ + auto envVarVal = std::getenv(envVarName.c_str()); + auto envVarValStr = std::string{envVarVal != nullptr ? envVarVal : ""}; + auto values = tensorrt_llm::common::str2set(envVarValStr, ','); + std::unordered_set startSet; + std::unordered_set endSet; + for (std::string const& value : values) + { + size_t dashIdx = value.find("-"); + if (dashIdx != std::string::npos) + { + int32_t start = std::stoi(value.substr(0, dashIdx)); + startSet.insert(start); + int32_t end = std::stoi(value.substr(dashIdx + 1)); + endSet.insert(end); + } + else + { + int32_t start_end = std::stoi(value); + startSet.insert(start_end); + endSet.insert(start_end); + } + } + + return std::make_pair(startSet, endSet); +} + +} // namespace + +namespace tensorrt_llm::common +{ + +std::pair, std::unordered_set> populateIterationIndexes( + std::string const& envVarName, std::optional const& legacyEnvVarName) +{ + auto [profileIterIdxs, stopIterIdxs] = populateIterationIndexesImpl(envVarName); + + // If empty, try to use legacy env var name + if (legacyEnvVarName && profileIterIdxs.empty() && stopIterIdxs.empty()) + { + std::tie(profileIterIdxs, stopIterIdxs) = populateIterationIndexesImpl(legacyEnvVarName.value()); + + if (!profileIterIdxs.empty() || !stopIterIdxs.empty()) + { + TLLM_LOG_WARNING( + "Using deprecated environment variable %s to specify cudaProfiler start and stop iterations. " + "Please " + "use %s " + "instead.", + legacyEnvVarName.value().c_str(), envVarName.c_str()); + } + } + + return std::make_pair(profileIterIdxs, stopIterIdxs); +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh new file mode 100644 index 00000000000..a0463a3a49e --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh @@ -0,0 +1,752 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" +#include "tensorrt_llm/common/cudaBf16Wrapper.h" +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include +#include +#include +#if ENABLE_BF16 +#include +#endif + +namespace tensorrt_llm +{ +namespace common +{ + +template +inline __device__ T ldg(T const* val) +{ + return __ldg(val); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 ldg(__nv_bfloat162 const* val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return val[0]; +#else + return __ldg(val); +#endif +} + +template <> +inline __device__ __nv_bfloat16 ldg(__nv_bfloat16 const* val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return val[0]; +#else + return __ldg(val); +#endif +} +#endif // ENABLE_BF16 + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter +{ + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter +{ + using Type = half; +}; + +template <> +struct TypeConverter +{ + using Type = half2; +}; + +#if ENABLE_BF16 +template <> +struct TypeConverter<__nv_bfloat162> +{ + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> +{ + using Type = __nv_bfloat162; +}; +#endif // ENABLE_BF16 + +// Defined math operations (bfloat16 fallback to fp32 when it is not supported) +template +inline __device__ T hadd2(T a, T b) +{ + return __hadd2(a, b); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hadd2(a, b); +} +#endif // ENABLE_BF16 + +template +inline __device__ T add(T a, T b) +{ + return a + b; +} + +template <> +inline __device__ half2 add(half2 a, half2 b) +{ + return __hadd2(a, b); +} + +template <> +inline __device__ half add(half a, half b) +{ + return __hadd(a, b); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hadd2(a, b); +} + +template <> +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) +{ + return bf16hadd(a, b); +} + +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b) +{ + return bf16hadd(a, __float2bfloat16(b)); +} +#endif // ENABLE_BF16 + +// applies to all 4 values addition +template +inline __device__ T add(T a, T b, T c) +{ + return a + b + c; +} + +#if ENABLE_BF16 +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ + return bf16hadd(a, b, c); +} + +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hadd2(a, b, c); +} +#endif // ENABLE_BF16 + +// applies to all 4 values addition +template +inline __device__ T add(T a, T b, T c, T d) +{ + return (T) ((float) a + (float) b + (float) c + (float) d); +} + +#if ENABLE_BF16 +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) +{ + return bf16hadd(a, b, c, d); +} +#endif // ENABLE_BF16 + +template +inline __device__ T hsub2(T a, T b) +{ + return __hsub2(a, b); +} + +#if ENABLE_BF16 +template <> +inline 
__device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hsub2(a, b); +} +#endif // ENABLE_BF16 + +template +inline __device__ T hmul2(T a, T b) +{ + return __hmul2(a, b); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hmul2(a, b); +} +#endif // ENABLE_BF16 + +template +inline __device__ T hmul2(T a, T b, T c) +{ + return a * b * c; +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hmul2(a, b, c); +} +#endif // ENABLE_BF16 + +template +inline __device__ T mul(T a, T b, T c) +{ + return a * b * c; +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ + return bf16hmul(a, b, c); +} + +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hmul2(a, b, c); +} +#endif // ENABLE_BF16 + +template +inline __device__ T fma(T a, T b, T c, T d) +{ + return a * b * c + d; +} + +#if ENABLE_BF16 +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) +{ + return bf16hfma2(a, b, c, d); +} +#endif // ENABLE_BF16 + +template +inline __device__ T fma(T a, T b, T c) +{ + return a * b + c; +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hfma2(a, b, c); +} + +template <> +inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ + return bf16hfma(a, b, c); +} +#endif // ENABLE_BF16 + +template +inline __device__ T hexp2(T a) +{ + return h2exp(a); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a) +{ + return bf16exp2(a); +} +#endif // ENABLE_BF16 + +template +__device__ inline T_OUT cuda_cast(T_IN val) +{ + return val; +} + +template <> +__device__ inline float2 cuda_cast(int2 val) +{ + return make_float2(val.x, val.y); +} + +template <> +__device__ inline float2 cuda_cast(float val) +{ + return make_float2(val, val); +} + +template <> +__device__ inline float2 cuda_cast(half2 val) +{ + return __half22float2(val); +} + +template <> +__device__ inline half2 cuda_cast(float2 val) +{ + return __float22half2_rn(val); +} + +template <> +__device__ inline half2 cuda_cast(float val) +{ + return __float2half2_rn(val); +} + +template <> +__device__ inline half2 cuda_cast(half val) +{ + return __half2half2(val); +} + +template <> +__device__ inline int8_t cuda_cast(half val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + union + { + half fp16; + int16_t int16_in; + }; + + fp16 = val; + asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in)); + return int8[0]; +} + +template <> +__device__ inline int16_t cuda_cast(half2 val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = cuda_cast(val.x); + int8[1] = cuda_cast(val.y); + return int16; +} + +template <> +__device__ inline int8_t cuda_cast(float val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); + return int8[0]; +} + +template <> +__device__ inline int16_t cuda_cast(float2 val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = cuda_cast(val.x); + int8[1] = cuda_cast(val.y); + return int16; +} + +template <> +__device__ inline half2 cuda_cast(int16_t val) +{ + union + { + int8_t 
int8[2]; + int16_t int16; + }; + + int16 = val; + return make_half2(int8[0], int8[1]); +} + +template <> +__device__ inline float2 cuda_cast(int16_t val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int16 = val; + return make_float2(int8[0], int8[1]); +} + +#ifdef ENABLE_BF16 +template <> +__device__ inline __nv_bfloat16 cuda_cast(int32_t val) +{ + return static_cast(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast(int8_t val) +{ + return static_cast(val); +} + +template <> +__device__ inline int8_t cuda_cast(__nv_bfloat16 val) +{ + return static_cast(val); +} + +template <> +__device__ inline float cuda_cast(__nv_bfloat16 val) +{ + return __bfloat162float(val); +} + +template <> +__device__ inline float2 cuda_cast(__nv_bfloat162 val) +{ + return bf1622float2(val); +} + +template <> +__device__ inline half cuda_cast(__nv_bfloat16 val) +{ + return __float2half(__bfloat162float(val)); +} + +template <> +__device__ inline int16_t cuda_cast(__nv_bfloat162 val) +{ + return bf1622int16(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) +{ + return __float2bfloat16(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val) +{ + return __float2bfloat16(__half2float(val)); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val) +{ + return bf162bf162(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val) +{ + return __float2bfloat162_rn(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val) +{ + return float22bf162(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int16 = val; + __nv_bfloat162 res; + res.x = cuda_cast<__nv_bfloat16>(int8[0]); + res.y = cuda_cast<__nv_bfloat16>(int8[1]); + return res; +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val) +{ + return float22bf162(__half22float2(val)); +} + +#endif // ENABLE BF16 + +template +__device__ inline T cuda_abs(T val) +{ + assert(false); + return {}; +} + +template <> +__device__ inline float cuda_abs(float val) +{ + return fabs(val); +} + +template <> +__device__ inline float2 cuda_abs(float2 val) +{ + return make_float2(fabs(val.x), fabs(val.y)); +} + +template <> +__device__ inline half cuda_abs(half val) +{ + return __habs(val); +} + +template <> +__device__ inline half2 cuda_abs(half2 val) +{ + return __habs2(val); +} + +#ifdef ENABLE_BF16 + +#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) +template <> +__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val) +{ + return __habs(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val) +{ + return __habs2(val); +} +#endif + +#endif // ENABLE_FP16 + +template +__device__ inline To cuda_sum(Ti val) +{ + return cuda_cast(val); +}; + +template +__device__ inline To cuda_sum(float2 val) +{ + return cuda_cast(val.x + val.y); +}; + +// Unary maximum: compute the max of a vector type +template +__device__ inline To cuda_max(Ti val) +{ + return cuda_cast(val); +}; + +template <> +__device__ inline float cuda_max(float2 val) +{ + return fmaxf(val.x, val.y); +} + +template <> +__device__ inline half cuda_max(half2 val) +{ + return __hmax(val.x, val.y); +} + +#ifdef ENABLE_BF16 +template <> +__device__ inline __nv_bfloat16 
cuda_max(__nv_bfloat162 val) +{ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __hmax(val.x, val.y); +#else + assert(0); + asm volatile("brkpt;\n" ::); + return __nv_bfloat16(0); +#endif +} +#endif + +// Binary maximum: compute the max of two values. +template +__device__ inline T cuda_max(T val1, T val2) +{ + return (val1 > val2) ? val1 : val2; +} + +template <> +__device__ inline float2 cuda_max(float2 val1, float2 val2) +{ + float2 out; + out.x = fmaxf(val1.x, val2.x); + out.y = fmaxf(val1.y, val2.y); + return out; +} + +template <> +__device__ inline half2 cuda_max(half2 val1, half2 val2) +{ + return __hmax2(val1, val2); +} + +#ifdef ENABLE_BF16 +template <> +__device__ inline __nv_bfloat162 cuda_max(__nv_bfloat162 val1, __nv_bfloat162 val2) +{ + return __hmax2(val1, val2); +} +#endif // ENABLE_BF16 + +// Binary maximum: compute the min of two values. +template +__device__ inline T cuda_min(T val1, T val2) +{ + return (val1 < val2) ? val1 : val2; +} + +template <> +__device__ inline float2 cuda_min(float2 val1, float2 val2) +{ + float2 out; + out.x = fminf(val1.x, val2.x); + out.y = fminf(val1.y, val2.y); + return out; +} + +template <> +__device__ inline half2 cuda_min(half2 val1, half2 val2) +{ + return __hmin2(val1, val2); +} + +#ifdef ENABLE_BF16 +template <> +__device__ inline __nv_bfloat162 cuda_min(__nv_bfloat162 val1, __nv_bfloat162 val2) +{ + return __hmin2(val1, val2); +} +#endif // ENABLE_BF16 + +// Helper function of clamping the val into the given range. +template +inline __device__ T cuda_clamp(T val, T minVal, T maxVal) +{ + return cuda_min(cuda_max(val, minVal), maxVal); +} + +#ifdef ENABLE_FP8 +template <> +__device__ inline float2 cuda_cast(__nv_fp8x2_e4m3 val) +{ + return bf1622float2(fp8x2_e4m3_to_bfloat2(&val)); +} + +template <> +__device__ inline half2 cuda_cast(__nv_fp8x2_e4m3 val) +{ + return fp8x2_e4m3_to_half2(&val); +} + +template <> +__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val) +{ + return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val))); +} + +template <> +__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, half2>(half2 val) +{ + return __nv_fp8x2_e4m3(cuda_cast(val)); +} + +template <> +__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, __nv_bfloat162>(__nv_bfloat162 val) +{ + return __nv_fp8x2_e4m3(cuda_cast(val)); +} + +template <> +__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val) +{ + return __nv_fp8_e4m3(val); +} + +template <> +__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val) +{ + return __nv_fp8_e4m3(val); +} + +template <> +__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val) +{ + return __nv_fp8_e4m3(val); +} + +template <> +__device__ inline float cuda_cast(__nv_fp8_e4m3 val) +{ + return (float) val; +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val) +{ + return fp8x2_e4m3_to_bfloat2(&val); +} + +template <> +__device__ inline int8_t cuda_cast(__nv_fp8_e4m3 val) +{ + // no impl + return 0; +} + +template <> +__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val) +{ + return cuda_cast<__nv_fp8_e4m3>(cuda_cast<__nv_bfloat16>(cuda_cast(val))); +} + +#endif // ENABLE_FP8 + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h new file mode 100644 index 
00000000000..d7bf43b4075 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace tensorrt_llm::utils::customAllReduceUtils +{ + +constexpr size_t NUM_POINTERS_PER_RANK = 7; + +// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py +inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept +{ + if (worldSize <= 2) + { + return 16 * 1000 * 1000; + } + return 8 * 1000 * 1000; +} + +} // namespace tensorrt_llm::utils::customAllReduceUtils diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp new file mode 100644 index 00000000000..64d3d44acb8 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp @@ -0,0 +1,214 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "envUtils.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" +#include + +namespace tensorrt_llm::common +{ + +std::optional getIntEnv(char const* name) +{ + char const* const env = std::getenv(name); + if (env == nullptr) + { + return std::nullopt; + } + int32_t const val = std::stoi(env); + if (val <= 0) + { + return std::nullopt; + } + return {val}; +}; + +// Returns true if the env variable exists and is set to "1" +static bool getBoolEnv(char const* name) +{ + char const* env = std::getenv(name); + return env && env[0] == '1' && env[1] == '\0'; +} + +// XQA kernels (optimized kernels for generation phase). +bool forceXQAKernels() +{ + static bool const forceXQA = (getIntEnv("TRTLLM_FORCE_XQA").value_or(0) != 0); + return forceXQA; +} + +std::optional getEnvEnableXQAJIT() +{ + static bool init = false; + static bool exists = false; + static bool enableXQAJIT = false; + if (!init) + { + init = true; + char const* enable_xqa_jit_var = std::getenv("TRTLLM_ENABLE_XQA_JIT"); + if (enable_xqa_jit_var) + { + exists = true; + if (enable_xqa_jit_var[0] == '1' && enable_xqa_jit_var[1] == '\0') + { + enableXQAJIT = true; + } + } + } + if (exists) + { + return enableXQAJIT; + } + else + { + return std::nullopt; + } +} + +// Tune the number of blocks per sequence for accuracy/performance purpose. 
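+//
+// The helpers below follow a common pattern: read the environment variable
+// once, cache the parsed value in function-local statics, and fall back to a
+// safe default when the value is missing or invalid. A hypothetical caller of
+// the raw accessor above (illustrative sketch only) could consume the knob as:
+//
+//   int const blocksPerSeq =
+//       tensorrt_llm::common::getIntEnv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE").value_or(0);
+//
+// where value_or(0) covers both an unset variable and a non-positive value,
+// since getIntEnv returns std::nullopt in either case.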
+bool getEnvMmhaMultiblockDebug() +{ + static bool init = false; + static bool forceMmhaMaxSeqLenTile = false; + if (!init) + { + init = true; + char const* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG"); + if (enable_mmha_debug_var) + { + if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0') + { + forceMmhaMaxSeqLenTile = true; + } + } + } + return forceMmhaMaxSeqLenTile; +} + +int getEnvMmhaBlocksPerSequence() +{ + static bool init = false; + static int mmhaBlocksPerSequence = 0; + if (!init) + { + init = true; + char const* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE"); + if (mmhaBlocksPerSequenceEnv) + { + mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv); + if (mmhaBlocksPerSequence <= 0) + { + TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_BLOCKS_PER_SEQUENCE. Will use default values instead!"); + } + } + } + return mmhaBlocksPerSequence; +} + +int getEnvMmhaKernelBlockSize() +{ + static bool init = false; + static int mmhaKernelBlockSize = 0; + if (!init) + { + init = true; + char const* mmhaKernelBlockSizeEnv = std::getenv("TRTLLM_MMHA_KERNEL_BLOCK_SIZE"); + if (mmhaKernelBlockSizeEnv) + { + mmhaKernelBlockSize = std::atoi(mmhaKernelBlockSizeEnv); + if (mmhaKernelBlockSize <= 0) + { + TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_KERNEL_BLOCK_SIZE. Will use default values instead!"); + } + } + } + return mmhaKernelBlockSize; +} + +bool getEnvEnablePDL() +{ + static bool init = false; + static bool enablePDL = false; + if (!init) + { + init = true; + // PDL only available when arch >= 90 + if (getSMVersion() >= 90) + { + // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1` + enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL"); + } + } + return enablePDL; +} + +bool getEnvUseUCXKvCache() +{ + static bool const useUCXKVCache = getBoolEnv("TRTLLM_USE_UCX_KVCACHE"); + return useUCXKVCache; +} + +std::string getEnvUCXInterface() +{ + static bool init = false; + static std::string ucxInterface; + if (!init) + { + init = true; + { + char const* ucx_interface = std::getenv("TRTLLM_UCX_INTERFACE"); + if (ucx_interface) + { + ucxInterface = ucx_interface; + } + } + } + return ucxInterface; +} + +bool getEnvDisaggLayerwise() +{ + static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE"); + return disaggLayerwise; +} + +bool getEnvParallelCacheSend() +{ + static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND"); + return parallelCacheSend; +} + +bool getEnvRequestKVCacheSerial() +{ + static bool const requestKVCacheSerial = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_SERIAL"); + return requestKVCacheSerial; +} + +bool getEnvDisableKVCacheTransferOverlap() +{ + static bool const disableKVCacheTransferOverlap = getBoolEnv("TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP"); + return disableKVCacheTransferOverlap; +} + +bool getEnvDisableReceiveKVCacheParallel() +{ + static bool const disableReceiveParallel = getBoolEnv("TRTLLM_DISABLE_KVCACHE_RECEIVE_PARALLEL"); + return disableReceiveParallel; +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h new file mode 100644 index 00000000000..027c7cfbb3b --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h @@ -0,0 +1,60 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +namespace tensorrt_llm::common +{ +// Useful when you want to inject some debug code controllable with env var. +std::optional getIntEnv(char const* name); + +// XQA kernels (optimized kernels for generation phase). +bool forceXQAKernels(); + +// Whether XQA JIT is enabled. +// +// Returns the value of TRTLLM_ENABLE_XQA_JIT env var. If such env var doesn't exist, std::nullopt is returned. +std::optional getEnvEnableXQAJIT(); + +// Tune the number of blocks per sequence for accuracy/performance purpose. +bool getEnvMmhaMultiblockDebug(); + +int getEnvMmhaBlocksPerSequence(); + +int getEnvMmhaKernelBlockSize(); + +// Whether PDL is enabled. +bool getEnvEnablePDL(); + +bool getEnvUseUCXKvCache(); + +std::string getEnvUCXInterface(); + +bool getEnvDisaggLayerwise(); + +bool getEnvParallelCacheSend(); + +bool getEnvRequestKVCacheSerial(); + +bool getEnvDisableKVCacheTransferOverlap(); + +bool getEnvDisableReceiveKVCacheParallel(); + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp new file mode 100644 index 00000000000..334ad236906 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/tllmException.h" +#include + +namespace tensorrt_llm::common +{ + +Logger::Logger() +{ + char* isFirstRankOnlyChar = std::getenv("TLLM_LOG_FIRST_RANK_ONLY"); + bool isFirstRankOnly = (isFirstRankOnlyChar != nullptr && std::string(isFirstRankOnlyChar) == "ON"); + + auto const* levelName = std::getenv("TLLM_LOG_LEVEL"); + if (levelName != nullptr) + { + auto level = [levelName = std::string(levelName)]() + { + if (levelName == "TRACE") + return TRACE; + if (levelName == "DEBUG") + return DEBUG; + if (levelName == "INFO") + return INFO; + if (levelName == "WARNING") + return WARNING; + if (levelName == "ERROR") + return ERROR; + TLLM_THROW("Invalid log level: %s", levelName.c_str()); + }(); + // If TLLM_LOG_FIRST_RANK_ONLY=ON, set LOG LEVEL of other device to ERROR + if (isFirstRankOnly) + { + auto const deviceId = getDevice(); + if (deviceId != 1) + { + level = ERROR; + } + } + setLevel(level); + } +} + +void Logger::log(std::exception const& ex, Logger::Level level) +{ + log(level, "%s: %s", TllmException::demangle(typeid(ex).name()).c_str(), ex.what()); +} + +Logger* Logger::getLogger() +{ + thread_local Logger instance; + return &instance; +} +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h new file mode 100644 index 00000000000..1bad3a2c152 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace tensorrt_llm +{ +namespace common +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ __host__ T divUp(T m, T n) +{ + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu new file mode 100644 index 00000000000..d13217b203a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu @@ -0,0 +1,906 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaTypeUtils.cuh" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/memoryUtils.h" + +#include +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ + +template +void deviceMalloc(T** ptr, size_t size, bool is_random_initialize) +{ + check_cuda_error(cudaMalloc((void**) (ptr), sizeof(T) * size)); + if (is_random_initialize) + { + cudaRandomUniform(*ptr, size); + } +} + +template void deviceMalloc(float** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(half** ptr, size_t size, bool is_random_initialize); +#ifdef ENABLE_BF16 +template void deviceMalloc(__nv_bfloat16** ptr, size_t size, bool is_random_initialize); +#endif +template void deviceMalloc(uint16_t** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(int** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(bool** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(char** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(int8_t** ptr, size_t size, bool is_random_initialize); +#ifdef ENABLE_FP8 +template void deviceMalloc(__nv_fp8_e4m3** ptr, size_t size, bool is_random_initialize); +#endif + +template +void deviceMemSetZero(T* ptr, size_t size) +{ + check_cuda_error(cudaMemset(static_cast(ptr), 0, sizeof(T) * size)); +} + +template void deviceMemSetZero(float* ptr, size_t size); +template void deviceMemSetZero(half* ptr, size_t size); +template void deviceMemSetZero(int* ptr, size_t size); +template void deviceMemSetZero(uint32_t* ptr, size_t size); +template void deviceMemSetZero(bool* ptr, size_t size); +#ifdef ENABLE_FP8 +template void deviceMemSetZero(__nv_fp8_e4m3* ptr, size_t size); +#endif +#ifdef ENABLE_BF16 +template void deviceMemSetZero(__nv_bfloat16* ptr, size_t size); +#endif + +template +void deviceFree(T*& ptr) +{ + if (ptr != NULL) + { + check_cuda_error(cudaFree(ptr)); + ptr = NULL; + } +} + +template void deviceFree(float*& ptr); +template void deviceFree(half*& ptr); +#ifdef ENABLE_BF16 +template void deviceFree(__nv_bfloat16*& ptr); +#endif +template void deviceFree(unsigned short*& ptr); +template void deviceFree(int*& ptr); +template void deviceFree(bool*& ptr); +template void deviceFree(char*& ptr); +template void deviceFree(int8_t*& ptr); +#ifdef ENABLE_FP8 +template void deviceFree(__nv_fp8_e4m3*& ptr); +#endif + +template +void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream) +{ + T* arr = new T[size]; + std::fill(arr, arr + size, value); + check_cuda_error(cudaMemcpyAsync(devptr, arr, sizeof(T) * size, cudaMemcpyHostToDevice, stream)); + delete[] arr; +} + +template void deviceFill(float* devptr, size_t size, float value, cudaStream_t stream); +template void deviceFill(half* devptr, size_t size, half value, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void deviceFill(__nv_bfloat16* devptr, size_t size, __nv_bfloat16 value, cudaStream_t stream); +#endif +template void deviceFill(int* devptr, size_t size, int value, cudaStream_t stream); +template void deviceFill(bool* devptr, size_t size, bool value, cudaStream_t stream); + +template +void cudaD2Hcpy(T* tgt, T const* src, const size_t size) +{ + check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToHost)); +} + +template void cudaD2Hcpy(float* tgt, float const* src, size_t size); 
+template void cudaD2Hcpy(half* tgt, half const* src, size_t size); +#ifdef ENABLE_BF16 +template void cudaD2Hcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size); +#endif +template void cudaD2Hcpy(int* tgt, int const* src, size_t size); +template void cudaD2Hcpy(bool* tgt, bool const* src, size_t size); +#ifdef ENABLE_FP8 +template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size); +#endif +template void cudaD2Hcpy(unsigned long long* tgt, unsigned long long const* src, size_t size); +template void cudaD2Hcpy(unsigned int* tgt, unsigned int const* src, size_t size); +template void cudaD2Hcpy(int8_t* tgt, int8_t const* src, size_t size); + +template +void cudaH2Dcpy(T* tgt, T const* src, const size_t size) +{ + check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyHostToDevice)); +} + +template void cudaH2Dcpy(float* tgt, float const* src, size_t size); +template void cudaH2Dcpy(half* tgt, half const* src, size_t size); +#ifdef ENABLE_BF16 +template void cudaH2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size); +#endif +template void cudaH2Dcpy(int* tgt, int const* src, size_t size); +template void cudaH2Dcpy(bool* tgt, bool const* src, size_t size); +#ifdef ENABLE_FP8 +template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size); +#endif +template void cudaH2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size); +template void cudaH2Dcpy(unsigned int* tgt, unsigned int const* src, size_t size); +template void cudaH2Dcpy(int8_t* tgt, int8_t const* src, size_t size); + +template +void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream) +{ + check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToDevice, stream)); +} + +template void cudaD2Dcpy(float* tgt, float const* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(half* tgt, half const* src, size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void cudaD2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream); +#endif +template void cudaD2Dcpy(int* tgt, int const* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream); +#ifdef ENABLE_FP8 +template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size, cudaStream_t stream); +#endif +template void cudaD2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream); + +template +__global__ void cudaCast(T_OUT* dst, T_IN* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = (T_OUT) ((float) (src[tid])); + } +} + +template +void invokeCudaCast(T_OUT* dst, T_IN const* const src, const size_t size, cudaStream_t stream) +{ + cudaCast<<<256, 256, 0, stream>>>(dst, src, size); +} + +template void invokeCudaCast(float* dst, half const* const src, const size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeCudaCast(float* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(__nv_bfloat16* dst, float const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(__nv_bfloat16* dst, half const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(half* dst, 
__nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); +#endif +#ifdef ENABLE_FP8 +template void invokeCudaCast(float* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast( + __nv_bfloat16* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(half* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(__nv_fp8_e4m3* dst, float const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast( + __nv_fp8_e4m3* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(__nv_fp8_e4m3* dst, half const* const src, const size_t size, cudaStream_t stream); +#endif + +template +void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream) +{ + if (stream != NULL) + { + check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDefault, stream)); + } + else + { + check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDefault)); + } +} + +template void cudaAutoCpy(float* tgt, float const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(half* tgt, half const* src, size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void cudaAutoCpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream); +#endif +template void cudaAutoCpy(int* tgt, int const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(uint8_t* tgt, uint8_t const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(uint32_t* tgt, uint32_t const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(unsigned long* tgt, unsigned long const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(char* tgt, char const* src, size_t size, cudaStream_t stream); + +template void cudaAutoCpy(float const** tgt, float const* const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(half const** tgt, half const* const* src, size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void cudaAutoCpy(__nv_bfloat16 const** tgt, __nv_bfloat16 const* const* src, size_t size, cudaStream_t stream); +#endif +template void cudaAutoCpy(int const** tgt, int const* const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(bool const** tgt, bool const* const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(int8_t const** tgt, int8_t const* const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy( + unsigned long long const** tgt, unsigned long long const* const* src, size_t size, cudaStream_t stream); + +template +__global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, int const seq_offset) +{ + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + curandState_t local_state; + curand_init((unsigned long long int) 1337, idx + seq_offset, 0, &local_state); + for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) + { + buffer[index] = (T) (curand_uniform(&local_state) * 0.2f - 0.1f); + } +} + +template <> +__global__ void cuda_random_uniform_kernel(int* buffer, 
const size_t size, int const seq_offset) +{ + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + curandState_t local_state; + curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); + for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) + { + buffer[index] = curand(&local_state); + } +} + +template <> +__global__ void cuda_random_uniform_kernel(bool* buffer, const size_t size, int const seq_offset) +{ + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + curandState_t local_state; + curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); + for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) + { + buffer[index] = (curand(&local_state) % 2 == 0); + } +} + +template <> +__global__ void cuda_random_uniform_kernel(char* buffer, const size_t size, int const seq_offset) +{ + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + curandState_t local_state; + curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); + for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) + { + buffer[index] = curand(&local_state) % 0xFF; + } +} + +template +void cudaRandomUniform(T* buffer, const size_t size) +{ + static int seq_offset = 0; + cuda_random_uniform_kernel<<<256, 256>>>(buffer, size, seq_offset); + seq_offset += 256 * 256; +} + +template void cudaRandomUniform(float* buffer, const size_t size); +template void cudaRandomUniform(half* buffer, const size_t size); +#ifdef ENABLE_BF16 +template void cudaRandomUniform(__nv_bfloat16* buffer, const size_t size); +#endif +template void cudaRandomUniform(int* buffer, const size_t size); +template void cudaRandomUniform(bool* buffer, const size_t size); +template void cudaRandomUniform(char* buffer, const size_t size); +#ifdef ENABLE_FP8 +template void cudaRandomUniform(__nv_fp8_e4m3* buffer, const size_t size); +#endif + +// loads data from binary file. If it succeeds, returns a non-empty vector. If loading fails or +// the product of the elements in shape is 0, this function will return an empty vector. +template +std::vector loadWeightFromBinHelper(std::vector shape, std::string filename) +{ + if (shape.size() > 2) + { + printf("[ERROR] shape should have less than two dims \n"); + return std::vector(); + } + size_t dim0 = shape[0], dim1 = 1; + if (shape.size() == 2) + { + dim1 = shape[1]; + } + size_t size = dim0 * dim1; + if (size == 0) + { + TLLM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str()); + return std::vector(); + } + + std::vector host_array(size); + std::ifstream in(filename, std::ios::in | std::ios::binary); + if (!in.is_open()) + { + TLLM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str()); + return std::vector(); + } + + size_t loaded_data_size = sizeof(T) * size; + in.seekg(0, in.end); + in.seekg(0, in.beg); + + TLLM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename); + in.read((char*) host_array.data(), loaded_data_size); + + size_t in_get_size = in.gcount(); + if (in_get_size != loaded_data_size) + { + TLLM_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! \n", filename.c_str(), + in_get_size, loaded_data_size); + return std::vector(); + } + in.close(); + // If we succeed, return an array with values. 
+ return host_array; +} + +template +int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filename) +{ + std::vector host_array = loadWeightFromBinHelper(shape, filename); + + if (host_array.empty()) + { + return 0; + } + + if (std::is_same::value == true) + { + cudaH2Dcpy(ptr, (T*) host_array.data(), host_array.size()); + } + else + { + T_IN* ptr_2 = nullptr; + deviceMalloc(&ptr_2, host_array.size(), false); + cudaH2Dcpy(ptr_2, host_array.data(), host_array.size()); + invokeCudaD2DcpyConvert(ptr, ptr_2, host_array.size()); + deviceFree(ptr_2); + } + return 0; +} + +template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(int8_t* ptr, std::vector shape, std::string filename); +#ifdef ENABLE_BF16 +template int loadWeightFromBinFunc<__nv_bfloat16, float>( + __nv_bfloat16* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc<__nv_bfloat16, half>( + __nv_bfloat16* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>( + __nv_bfloat16* ptr, std::vector shape, std::string filename); +#endif // ENABLE_BF16 +template int loadWeightFromBinFunc(int* ptr, std::vector shape, std::string filename); +#ifdef ENABLE_FP8 +template int loadWeightFromBinFunc<__nv_fp8_e4m3, float>( + __nv_fp8_e4m3* ptr, std::vector shape, std::string filename); +#endif // ENABLE_FP8 + +template +int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type) +{ + switch (model_file_type) + { + case TRTLLMCudaDataType::FP32: loadWeightFromBinFunc(ptr, shape, filename); break; + case TRTLLMCudaDataType::FP16: loadWeightFromBinFunc(ptr, shape, filename); break; + case TRTLLMCudaDataType::INT8: loadWeightFromBinFunc(ptr, shape, filename); break; +#ifdef ENABLE_BF16 + case TRTLLMCudaDataType::BF16: loadWeightFromBinFunc(ptr, shape, filename); break; +#endif +#ifdef ENABLE_FP8 + case TRTLLMCudaDataType::FP8: loadWeightFromBinFunc(ptr, shape, filename); break; +#endif + default: TLLM_LOG_ERROR("Does not support TRTLLMCudaDataType=%d", model_file_type); TLLM_CHECK(false); + } + return 0; +} + +template <> +int loadWeightFromBin(int* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type) +{ + loadWeightFromBinFunc(ptr, shape, filename); + return 0; +} + +template int loadWeightFromBin( + float* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +template int loadWeightFromBin( + half* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +template int loadWeightFromBin( + int8_t* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +#ifdef ENABLE_BF16 +template int loadWeightFromBin( + __nv_bfloat16* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +#endif +#ifdef ENABLE_FP8 +template int loadWeightFromBin( + __nv_fp8_e4m3* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +#endif +template int 
loadWeightFromBin( + int* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); + +template +__global__ void cudaD2DcpyConvert(T_OUT* dst, const T_IN* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = cuda_cast(src[tid]); + } +} + +template +void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cudaStream_t stream) +{ + cudaD2DcpyConvert<<<256, 256, 0, stream>>>(tgt, src, size); +} + +template void invokeCudaD2DcpyConvert(int8_t* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, int8_t const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(half* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(half* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, half const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(uint32_t* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, uint32_t const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, half const* src, const size_t size, cudaStream_t stream); + +#ifdef ENABLE_BF16 +template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream); +#endif // ENABLE_BF16 + +template +__global__ void cudaD2DScaleCpyConvert( + T_OUT* dst, const T_IN* src, float const* scale, bool invert_scale, const size_t size) +{ + float const scale_value = invert_scale ? 
1.0f / scale[0] : scale[0]; + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = cuda_cast(cuda_cast(src[tid]) * scale_value); + } +} + +template +void invokeCudaD2DScaleCpyConvert( + T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream) +{ + cudaD2DScaleCpyConvert<<<256, 256, 0, stream>>>(tgt, src, scale, invert_scale, size); +} + +// clang-format off +template void invokeCudaD2DScaleCpyConvert(float* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const float* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +template void invokeCudaD2DScaleCpyConvert(half* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const half* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeCudaD2DScaleCpyConvert(__nv_bfloat16* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const __nv_bfloat16* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +#endif // ENABLE_BF16 +#ifdef ENABLE_FP8 +template void invokeCudaD2DScaleCpyConvert(float* tgt, const __nv_fp8_e4m3* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +#endif // ENABLE_FP8 +// clang-format on + +void invokeCudaD2DcpyHalf2Float(float* dst, half* src, const size_t size, cudaStream_t stream) +{ + invokeCudaD2DcpyConvert(dst, src, size, stream); +} + +void invokeCudaD2DcpyFloat2Half(half* dst, float* src, const size_t size, cudaStream_t stream) +{ + invokeCudaD2DcpyConvert(dst, src, size, stream); +} + +template +void saveToBinary(T const* ptr, const size_t size, std::string filename) +{ + + std::vector h_ptr(size); + cudaD2Hcpy(h_ptr.data(), ptr, size); + std::vector float_ptr(size); + for (size_t i = 0; i < size; i++) + { + float_ptr[i] = (float) h_ptr[i]; + } + + std::ofstream out(filename, std::ios::out | std::ios::binary); + TLLM_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename); + + out.write((char*) float_ptr.data(), size * sizeof(float)); +} + +template void saveToBinary(float const* ptr, const size_t size, std::string filename); +template void saveToBinary(half const* ptr, const size_t size, std::string filename); +#ifdef ENABLE_BF16 +template void saveToBinary(__nv_bfloat16 const* ptr, const size_t size, std::string filename); +#endif // ENABLE_BF16 + +template <> +void saveToBinary(int const* ptr, const size_t size, std::string filename) +{ + std::vector h_ptr(size); + cudaD2Hcpy(h_ptr.data(), ptr, size); + std::ofstream out(filename, std::ios::out | std::ios::binary); + TLLM_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename); + out.write((char*) h_ptr.data(), size * sizeof(int)); +} + +template +__global__ void fakeCast(T_IN* input_ptr, const size_t size) +{ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) + { + T_fake_type tmp_val = (T_fake_type) ((float) input_ptr[i]); + input_ptr[i] = (T_IN) ((float) tmp_val); + } +} + +template +void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream) +{ + dim3 block(256); + dim3 
grid((size + 255) / 256); + fakeCast<<>>(input_ptr, size); +} + +#ifdef ENABLE_FP8 +__global__ void cudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = (float) (src[tid]); + } +} + +void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, const size_t size, cudaStream_t stream) +{ + cudaD2Dcpyfp82Float<<<256, 256, 0, stream>>>(dst, src, size); +} + +__global__ void cudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = (half) ((float) (src[tid])); + } +} + +void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, const size_t size, cudaStream_t stream) +{ + cudaD2Dcpyfp82Half<<<256, 256, 0, stream>>>(dst, src, size); +} + +__global__ void cudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = (__nv_fp8_e4m3) src[tid]; + } +} + +void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, const size_t size, cudaStream_t stream) +{ + cudaD2DcpyFloat2fp8<<<256, 256, 0, stream>>>(dst, src, size); +} + +__global__ void cudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = (__nv_fp8_e4m3) src[tid]; + } +} + +void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, const size_t size, cudaStream_t stream) +{ + cudaD2DcpyHalf2fp8<<<256, 256, 0, stream>>>(dst, src, size); +} + +__global__ void cudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = (__nv_fp8_e4m3) src[tid]; + } +} + +void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, const size_t size, cudaStream_t stream) +{ + cudaD2DcpyBfloat2fp8<<<256, 256, 0, stream>>>(dst, src, size); +} + +#endif // ENABLE_FP8 + +template +__global__ void transpose(T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1; tid += blockDim.x * gridDim.x) + { + const size_t src_col_id = tid % dim1; + const size_t src_row_id = tid / dim1; + dst[src_col_id * dim0 + src_row_id] = (T_OUT) (src[tid]); + } +} + +template +void invokeInPlaceTranspose(T* data, T* workspace, const size_t dim0, const size_t dim1) +{ + // copy data to workspace, and then transpose from workspace to data + cudaD2Dcpy(workspace, data, dim0 * dim1); + transpose<<<256, 256>>>(data, workspace, dim0, dim1); +} + +#ifdef ENABLE_FP8 +template void invokeInPlaceTranspose( + __nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, const size_t dim1); +#endif // ENABLE_FP8 +#ifdef ENABLE_BF16 +template void invokeInPlaceTranspose( + __nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, const size_t dim1); +#endif // ENABLE_BF16 +template void invokeInPlaceTranspose(float* data, float* workspace, const size_t dim0, const size_t dim1); + +template +__global__ void transpose0213( + T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3) +{ + // src permutation: [0, 1, 2, 3] + // dst permutation: [0, 2, 1, 3] + for (size_t tid = 
threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2 * dim3; + tid += blockDim.x * gridDim.x) + { + size_t tmp_idx = tid; + const size_t dim_3_idx = tmp_idx % dim3; + tmp_idx = (tmp_idx - dim_3_idx) / dim3; + const size_t dim_2_idx = tmp_idx % dim2; + tmp_idx = (tmp_idx - dim_2_idx) / dim2; + const size_t dim_1_idx = tmp_idx % dim1; + tmp_idx = (tmp_idx - dim_1_idx) / dim1; + const size_t dim_0_idx = tmp_idx % dim0; + dst[dim_0_idx * dim1 * dim2 * dim3 + dim_2_idx * dim1 * dim3 + dim_1_idx * dim3 + dim_3_idx] = src[tid]; + } +} + +template +void invokeInPlaceTranspose0213( + T* data, T* workspace, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3) +{ + // copy data to workspace, and then transpose from workspace to data + // Note that this kernel is used for pre-processing and not very efficient. + cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2 * dim3); + transpose0213<<<256, 256>>>(data, workspace, dim0, dim1, dim2, dim3); +} + +#ifdef ENABLE_FP8 +template void invokeInPlaceTranspose0213(__nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, + const size_t dim1, const size_t dim2, const size_t dim3); +#endif // ENABLE_FP8 +#ifdef ENABLE_BF16 +template void invokeInPlaceTranspose0213(__nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, + const size_t dim1, const size_t dim2, const size_t dim3); +#endif // ENABLE_BF16 +template void invokeInPlaceTranspose0213( + float* data, float* workspace, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3); + +template +__global__ void transpose102(T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1, const size_t dim2) +{ + // src permutation: [0, 1, 2] + // dst permutation: [1, 0, 2] + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2; tid += blockDim.x * gridDim.x) + { + size_t tmp_idx = tid; + const size_t dim_2_idx = tmp_idx % dim2; + tmp_idx = (tmp_idx - dim_2_idx) / dim2; + const size_t dim_1_idx = tmp_idx % dim1; + tmp_idx = (tmp_idx - dim_1_idx) / dim1; + const size_t dim_0_idx = tmp_idx % dim0; + dst[dim_1_idx * dim0 * dim2 + dim_0_idx * dim2 + dim_2_idx] = src[tid]; + } +} + +template +void invokeInPlaceTranspose102(T* data, T* workspace, const size_t dim0, const size_t dim1, const size_t dim2) +{ + // copy data to workspace, and then transpose from workspace to data + // Note that this kernel is used for pre-processing and not very efficient. 
+ cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2); + transpose102<<<256, 256>>>(data, workspace, dim0, dim1, dim2); +} + +#ifdef ENABLE_FP8 +template void invokeInPlaceTranspose102( + __nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, const size_t dim1, const size_t dim2); +#endif // ENABLE_FP8 +#ifdef ENABLE_BF16 +template void invokeInPlaceTranspose102( + __nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, const size_t dim1, const size_t dim2); +#endif // ENABLE_BF16 +template void invokeInPlaceTranspose102( + float* data, float* workspace, const size_t dim0, const size_t dim1, const size_t dim2); + +template +void __global__ multiplyScale(T* tensor, float scale, const size_t size) +{ + for (size_t index = threadIdx.x + blockIdx.x * blockDim.x; index < size; index += blockDim.x * gridDim.x) + { + tensor[index] = (T) (((float) tensor[index]) * scale); + } +} + +template +void invokeMultiplyScale(T* tensor, float scale, const size_t size, cudaStream_t stream) +{ + int block = 256; + int grid = (size + 255) / 256; + multiplyScale<<>>(tensor, scale, size); +} + +template void invokeMultiplyScale(float* tensor, float scale, const size_t size, cudaStream_t stream); +template void invokeMultiplyScale(half* tensor, float scale, const size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeMultiplyScale(__nv_bfloat16* tensor, float scale, const size_t size, cudaStream_t stream); +#endif +#ifdef ENABLE_FP8 +template void invokeMultiplyScale(__nv_fp8_e4m3* tensor, float scale, const size_t size, cudaStream_t stream); +#endif + +template +void __global__ divideScale(T* tensor, float scale, const size_t size) +{ + for (size_t index = threadIdx.x + blockIdx.x * blockDim.x; index < size; index += blockDim.x * gridDim.x) + { + tensor[index] = (T) (((float) tensor[index]) / scale); + } +} + +template +void invokeDivideScale(T* tensor, float scale, const size_t size, cudaStream_t stream) +{ + int block = 256; + int grid = (size + 255) / 256; + divideScale<<>>(tensor, scale, size); +} + +template void invokeDivideScale(float* tensor, float scale, const size_t size, cudaStream_t stream); +template void invokeDivideScale(half* tensor, float scale, const size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeDivideScale(__nv_bfloat16* tensor, float scale, const size_t size, cudaStream_t stream); +#endif +#ifdef ENABLE_FP8 +template void invokeDivideScale(__nv_fp8_e4m3* tensor, float scale, const size_t size, cudaStream_t stream); +#endif +#ifdef ENABLE_BF16 +template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast<__nv_bfloat16, __nv_bfloat16>( + __nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast(half* input_ptr, const size_t size, cudaStream_t stream); +#endif +template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); +#ifdef ENABLE_FP8 +template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast(half* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast<__nv_bfloat16, __nv_fp8_e4m3>( + __nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream); +#endif + +size_t cuda_datatype_size(TRTLLMCudaDataType dt) +{ + static const std::unordered_map sizes{ + {TRTLLMCudaDataType::FP32, sizeof(float)}, {TRTLLMCudaDataType::FP16, 
sizeof(half)} +#ifdef ENABLE_BF16 + , + {TRTLLMCudaDataType::BF16, sizeof(__nv_bfloat16)} +#endif + }; + + return sizes.at(dt); +} + +template +__global__ void check_range(T const* buffer, size_t size, T min, T max, bool* d_within_range) +{ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) + { + const T val = buffer[i]; + if (val < min || val > max) + { + *d_within_range = false; + } + } +} + +template +bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream) +{ + cudaMemsetAsync(d_within_range, true, sizeof(bool), stream); + + dim3 block(256); + dim3 grid((size + 255) / 256); + check_range<<>>(buffer, size, min, max, d_within_range); + + bool result; + cudaD2Hcpy(&result, d_within_range, 1); + return result; +} + +template bool invokeCheckRange( + int const* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream); + +/* + * Determine the total workspace size based on a vector containing multiple variable sizes. + */ +size_t calcAlignedSize(std::vector const& sizes, const size_t ALIGN_BYTES) +{ + const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); + // Check ALIGN_BYTES is a power of 2 + assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0); + + size_t total = 0; + for (auto sz : sizes) + { + total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; + } + + // We add extra "ALIGN_BYTES - 1" bytes in case the start address passed to the function calcAlignedPointers() is + // not aligned. + return total + ALIGN_BYTES - 1; +} + +/* + * Given the address of the workspace and the vector containing multiple variable sizes, calculate the start addresses + * of each variable. + */ +void calcAlignedPointers( + std::vector& outPtrs, void const* p, std::vector const& sizes, size_t ALIGN_BYTES) +{ + const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); + // Check ALIGN_BYTES is a power of 2 + assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0); + + // In case the start address is not aligned + char* ptr = reinterpret_cast((reinterpret_cast(p) + ALIGN_BYTES - 1) & ALIGN_MASK); + + outPtrs.reserve(sizes.size()); + for (auto sz : sizes) + { + outPtrs.push_back(ptr); + ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; + } +} + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h new file mode 100644 index 00000000000..9e413a1beb8 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h @@ -0,0 +1,292 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include "tensorrt_llm/common/cudaUtils.h" + +#include + +namespace tensorrt_llm +{ +namespace common +{ + +template +void deviceMalloc(T** ptr, size_t size, bool is_random_initialize = true); + +template +void deviceMemSetZero(T* ptr, size_t size); + +template + +void deviceFree(T*& ptr); + +template +void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0); + +template +void cudaD2Hcpy(T* tgt, T const* src, size_t const size); + +template +void cudaH2Dcpy(T* tgt, T const* src, size_t const size); + +template +void cudaD2Dcpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL); + +template +void cudaAutoCpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL); + +template +void cudaRandomUniform(T* buffer, size_t const size); + +template +int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, + TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32); + +// template +// int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr, +// T* scale_ptr, +// std::vector shape, +// std::string filename, +// TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32); + +void invokeCudaD2DcpyHalf2Float(float* dst, half* src, size_t const size, cudaStream_t stream); +void invokeCudaD2DcpyFloat2Half(half* dst, float* src, size_t const size, cudaStream_t stream); +#ifdef ENABLE_FP8 +void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream); +void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream); +void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, size_t const size, cudaStream_t stream); +void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, size_t const size, cudaStream_t stream); +void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream); +#endif // ENABLE_FP8 +#ifdef ENABLE_BF16 +void invokeCudaD2DcpyBfloat2Float(float* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream); +#endif // ENABLE_BF16 + +template +void invokeCudaCast(T_OUT* dst, T_IN const* const src, size_t const size, cudaStream_t stream); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The following functions implement conversion of multi-dimensional indices to an index in a flat array. +// The shape of the Tensor dimensions is passed as one array (`dims`), the indices are given as individual arguments. +// For examples on how to use these functions, see their tests `test_memory_utils.cu`. +// All of these functions can be evaluated at compile time by recursive template expansion. + +template +__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( + T const& acc, TDim dims, TIndex const& index) +{ + assert(index < dims[0]); + return acc * dims[0] + index; +} + +template +__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( + T const& acc, TDim dims, TIndex const& index, TIndices... 
indices) +{ + assert(index < dims[0]); + return flat_index(acc * dims[0] + index, dims + 1, indices...); +} + +template +__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( + [[maybe_unused]] TDim dims, T const& index) +{ + assert(index < dims[0]); + return index; +} + +template +__inline__ __host__ __device__ + std::enable_if_t::value, typename std::remove_pointer::type> constexpr flat_index( + TDim dims, TIndex const& index, TIndices... indices) +{ + assert(index < dims[0]); + return flat_index(static_cast::type>(index), dims + 1, indices...); +} + +template +__inline__ __host__ __device__ T constexpr flat_index( + std::array const& dims, TIndex const& index, TIndices... indices) +{ + static_assert(skip < N); + static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); + return flat_index(&dims[skip], index, indices...); +} + +template +__inline__ __host__ __device__ T constexpr flat_index( + T const& acc, std::array const& dims, TIndex const& index, TIndices... indices) +{ + static_assert(skip < N); + static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); + return flat_index(acc, &dims[skip], index, indices...); +} + +template +__inline__ __host__ __device__ T constexpr flat_index(T const (&dims)[N], TIndex const& index, TIndices... indices) +{ + static_assert(skip < N); + static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); + return flat_index(static_cast(dims) + skip, index, indices...); +} + +template +__inline__ __host__ __device__ T constexpr flat_index( + T const& acc, T const (&dims)[N], TIndex const& index, TIndices... indices) +{ + static_assert(skip < N); + static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); + return flat_index(acc, static_cast(dims) + skip, index, indices...); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// These are simpler functions for multi-dimensional index conversion. Indices and dimensions are passed as individual +// arguments. These functions are more suitable for usage inside kernels than the corresponding flat_index functions +// which require arrays as arguments. Usage examples can be found in `test_memory_utils.cu`. The functions can be +// evaluated at compile time. 
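// Illustrative usage sketch (editor's note, not from the upstream sources): for a row-major
// buffer with logical shape [dim0][dim1][dim2], the helpers declared below compose as
//     flat_index3(i0, i1, i2, dim1, dim2) == (i0 * dim1 + i1) * dim2 + i2
// so indexing element (b, h, s) of a hypothetical [batch][numHeads][seqLen] buffer would be:
//
//     size_t const idx = tensorrt_llm::common::flat_index3(b, h, s, numHeads, seqLen);
//     float const value = buffer[idx];
//
// The *_strided variants take precomputed strides instead of dimensions, e.g.
//     flat_index_strided3(i0, i1, i2, stride1, stride2) == i0 * stride1 + i1 * stride2 + i2.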
+ +template +__inline__ __host__ __device__ T constexpr flat_index2(TIndex const& index_0, TIndex const& index_1, T const& dim_1) +{ + assert(index_1 < dim_1); + return index_0 * dim_1 + index_1; +} + +template +__inline__ __host__ __device__ T constexpr flat_index3( + TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& dim_1, T const& dim_2) +{ + assert(index_2 < dim_2); + return flat_index2(index_0, index_1, dim_1) * dim_2 + index_2; +} + +template +__inline__ __host__ __device__ T constexpr flat_index4(TIndex const& index_0, TIndex const& index_1, + TIndex const& index_2, TIndex const& index_3, T const& dim_1, T const& dim_2, T const& dim_3) +{ + assert(index_3 < dim_3); + return flat_index3(index_0, index_1, index_2, dim_1, dim_2) * dim_3 + index_3; +} + +template +__inline__ __host__ __device__ T constexpr flat_index5(TIndex const& index_0, TIndex const& index_1, + TIndex const& index_2, TIndex const& index_3, TIndex const& index_4, T const& dim_1, T const& dim_2, T const& dim_3, + T const& dim_4) +{ + assert(index_4 < dim_4); + return flat_index4(index_0, index_1, index_2, index_3, dim_1, dim_2, dim_3) * dim_4 + index_4; +} + +template +__inline__ __host__ __device__ T constexpr flat_index_strided3( + TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& stride_1, T const& stride_2) +{ + assert(index_1 < stride_1 / stride_2); + assert(index_2 < stride_2); + return index_0 * stride_1 + index_1 * stride_2 + index_2; +} + +template +__inline__ __host__ __device__ T constexpr flat_index_strided4(TIndex const& index_0, TIndex const& index_1, + TIndex const& index_2, TIndex const& index_3, T const& stride_1, T const& stride_2, T const& stride_3) +{ + assert(index_1 < stride_1 / stride_2); + assert(index_2 < stride_2 / stride_3); + assert(index_3 < stride_3); + return index_0 * stride_1 + index_1 * stride_2 + index_2 * stride_3 + index_3; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void invokeInPlaceTranspose(T* data, T* workspace, size_t const dim0, size_t const dim1); + +template +void invokeInPlaceTranspose0213( + T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2, size_t const dim3); + +template +void invokeInPlaceTranspose102(T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2); + +template +void invokeMultiplyScale(T* tensor, float scale, size_t const size, cudaStream_t stream); + +template +void invokeDivideScale(T* tensor, float scale, size_t const size, cudaStream_t stream); + +template +void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, size_t const size, cudaStream_t stream = 0); + +template +void invokeCudaD2DScaleCpyConvert( + T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, size_t const size, cudaStream_t stream = 0); + +inline bool checkIfFileExist(std::string const& file_path) +{ + std::ifstream in(file_path, std::ios::in | std::ios::binary); + if (in.is_open()) + { + in.close(); + return true; + } + return false; +} + +template +void saveToBinary(T const* ptr, size_t const size, std::string filename); + +template +void invokeFakeCast(T_IN* input_ptr, size_t const size, cudaStream_t stream); + +size_t cuda_datatype_size(TRTLLMCudaDataType dt); + +template +bool invokeCheckRange(T const* buffer, size_t const size, T min, T max, bool* d_within_range, cudaStream_t stream); + +constexpr size_t DEFAULT_ALIGN_BYTES = 256; + +size_t calcAlignedSize(std::vector const& sizes, 
size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES); +void calcAlignedPointers(std::vector& outPtrs, void const* p, std::vector const& sizes, + size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES); + +struct AlignedPointersUnpacker +{ + template + void operator()(T*&... outPtrs) + { + assert(sizeof...(T) == alignedPointers.size()); + auto it = alignedPointers.begin(); + ((outPtrs = static_cast(*it++)), ...); + } + + std::vector alignedPointers; +}; + +AlignedPointersUnpacker inline calcAlignedPointers( + void const* p, std::vector const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES) +{ + AlignedPointersUnpacker unpacker{}; + calcAlignedPointers(unpacker.alignedPointers, p, sizes, ALIGN_BYTES); + return unpacker; +} + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp new file mode 100644 index 00000000000..dbdaca4ee77 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp @@ -0,0 +1,588 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "tensorrt_llm/common/mpiUtils.h" + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/runtime/common.h" +#include "tensorrt_llm/runtime/iBuffer.h" + +#include +#include +#include +#include +#include +#ifndef _WIN32 +#include +#endif + +// We rely on SizeType32 being int32_t in some places with weak type checking, +// i.e. we're passing void ptr to some function. To prevent mysterious errors +// in the future, we trigger a compilation error here if SizeType32 isn't int32_t. 
+static_assert(std::is_same::value); + +namespace tensorrt_llm::mpi +{ + +MPI_Datatype getMpiDtype(MpiType dtype) +{ +#if ENABLE_MULTI_DEVICE + static std::unordered_map const dtype_map{ + {MpiType::kBYTE, MPI_BYTE}, + {MpiType::kHALF, MPI_UINT16_T}, + {MpiType::kFLOAT, MPI_FLOAT}, + {MpiType::kDOUBLE, MPI_DOUBLE}, + {MpiType::kBOOL, MPI_C_BOOL}, + {MpiType::kINT8, MPI_INT8_T}, + {MpiType::kUINT8, MPI_UINT8_T}, + {MpiType::kINT32, MPI_INT32_T}, + {MpiType::kUINT32, MPI_UINT32_T}, + {MpiType::kINT64, MPI_INT64_T}, + {MpiType::kUINT64, MPI_UINT64_T}, + {MpiType::kFP8, MPI_UINT8_T}, + {MpiType::kBF16, MPI_UINT16_T}, + {MpiType::kCHAR, MPI_CHAR}, + }; + return dtype_map.at(dtype); +#else + TLLM_THROW("Multi device support is disabled."); +#endif +} + +MPI_Op getMpiOp(MpiOp op) +{ +#if ENABLE_MULTI_DEVICE + static std::unordered_map const op_map{ + {MpiOp::NULLOP, MPI_OP_NULL}, + {MpiOp::MAX, MPI_MAX}, + {MpiOp::MIN, MPI_MIN}, + {MpiOp::SUM, MPI_SUM}, + {MpiOp::PROD, MPI_PROD}, + {MpiOp::LAND, MPI_LAND}, + {MpiOp::BAND, MPI_BAND}, + {MpiOp::LOR, MPI_LOR}, + {MpiOp::BOR, MPI_BOR}, + {MpiOp::LXOR, MPI_LXOR}, + {MpiOp::BXOR, MPI_BXOR}, + {MpiOp::MINLOC, MPI_MINLOC}, + {MpiOp::MAXLOC, MPI_MAXLOC}, + {MpiOp::REPLACE, MPI_REPLACE}, + }; + return op_map.at(op); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +namespace +{ + +bool mpiInitialized = false; +std::recursive_mutex mpiMutex; + +MpiComm initLocalSession() +{ +#if ENABLE_MULTI_DEVICE + MPI_Comm localComm = nullptr; + MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, &localComm); + MpiComm localSession{localComm, false}; +#else + MpiComm localSession{COMM_SESSION, false}; +#endif // ENABLE_MULTI_DEVICE + return localSession; +} + +} // namespace + +std::vector getWorldRanks(MpiComm const& comm) +{ +#if ENABLE_MULTI_DEVICE + MPI_Group group = nullptr; + MPI_Group worldGroup = nullptr; + + MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); + MPICHECK(MPI_Comm_group(comm, &group)); + + int groupSize = 0; + MPICHECK(MPI_Group_size(group, &groupSize)); + std::vector ranks(groupSize); + std::vector worldRanks(groupSize); + std::iota(ranks.begin(), ranks.end(), 0); + + MPICHECK(MPI_Group_translate_ranks(group, groupSize, ranks.data(), worldGroup, worldRanks.data())); + MPICHECK(MPI_Group_free(&group)); + MPICHECK(MPI_Group_free(&worldGroup)); +#else + std::vector worldRanks{0}; +#endif + return worldRanks; +} + +void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent) +{ + // double-checked locking + if (mpiInitialized) + { + return; + } + std::lock_guard lk(mpiMutex); + if (mpiInitialized) + { + return; + } +#if ENABLE_MULTI_DEVICE + int initialized = 0; + TLLM_MPI_CHECK(MPI_Initialized(&initialized)); + if (!initialized) + { + TLLM_LOG_INFO("Initializing MPI with thread mode %d", threadMode); + int providedMode = 0; + auto requiredMode = static_cast(threadMode); + MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode)); + TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed"); + std::atexit([]() { MPI_Finalize(); }); + + /* + * We only catch SIGABRT and SIGSEGV because most, of not all errors in the worker will cause one of these 2 + * signals. Signals like SIGINT and SIGTERM should be issued to the parent and should terminate MPI workers + * correctly. 
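+         * When forwardAbortToParent is set, the handler installed below additionally sends SIGKILL to the parent
+         * process, so the launcher tears the whole job down before MPI_Abort is issued.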
+ */ + for (int sig : {SIGABRT, SIGSEGV}) + { + __sighandler_t previousHandler = nullptr; + if (forwardAbortToParent) + { + previousHandler = std::signal(sig, + [](int signal) + { +#ifndef _WIN32 + pid_t parentProcessId = getppid(); + kill(parentProcessId, SIGKILL); +#endif + MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); + }); + } + else + { + previousHandler = std::signal(sig, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); }); + } + TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed"); + } + + // ensure local MPI communicator is initialized + MpiComm::localSession(); + TLLM_LOG_INFO("Initialized MPI"); + } +#endif // ENABLE_MULTI_DEVICE + mpiInitialized = true; +} + +void MpiComm::barrier() const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Barrier(mComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +#if ENABLE_MULTI_DEVICE +template >>> +size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dtype, TArgs... args) +{ + constexpr auto maxP1 = static_cast(std::numeric_limits::max()) + 1; + if (TLLM_LIKELY(size < maxP1)) + { + MPICHECK(func(buffer, size, dtype, args...)); + return 1; + } + + constexpr size_t alignment = 256; + int elementSize = 1; + MPICHECK(MPI_Type_size(dtype, &elementSize)); + elementSize = std::min(elementSize, alignment); + + // We cap at max alignment-bytes chunks that can be sent at once. + auto const step = maxP1 - (alignment / elementSize); + + using TCast = std::conditional_t, uint8_t const, uint8_t>; + size_t count = 0; + while (size != 0) + { + auto currentStep = static_cast(std::min(size, step)); + MPICHECK(func(buffer, currentStep, dtype, args...)); + size -= currentStep; + size_t diff = static_cast(currentStep) * elementSize; + buffer = static_cast(buffer) + diff; + ++count; + } + + return count; +} +#endif // ENABLE_MULTI_DEVICE + +std::shared_ptr MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const +{ + std::shared_ptr r = std::make_shared(); +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + return r; +} + +std::shared_ptr MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const +{ + TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU); + return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root); +} + +void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const +{ +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Bcast, buffer, size, getMpiDtype(dtype), root, mComm); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::bcast(runtime::IBuffer& buf, int root) const +{ + bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root); +} + +std::shared_ptr MpiComm::sendAsync(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const +{ + TLLM_LOG_DEBUG("start MPI_Isend with size %d", size); + std::shared_ptr r = std::make_shared(); +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Isend, buffer, size, getMpiDtype(dtype), dest, tag, mComm, &r->mRequest); +#else + TLLM_THROW("Multi device support is disabled."); +#endif + TLLM_LOG_DEBUG("end MPI_Isend with size %d", size); + return r; +} + +std::shared_ptr MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, int tag) const +{ + return sendAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag); +} + +void 
MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const +{ + TLLM_LOG_DEBUG("start MPI_Send with size %d", size); +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + TLLM_LOG_DEBUG("end MPI_Send with size %d", size); +} + +void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const +{ + send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag); +} + +MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const +{ + TLLM_LOG_DEBUG("start MPI_Recv with size %d", size); + MPI_Status status{}; +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Recv, buffer, size, getMpiDtype(dtype), source, tag, mComm, &status); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + TLLM_LOG_DEBUG("end MPI_Recv with size %d", size); + return status; +} + +MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const +{ + return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag); +} + +MpiComm MpiComm::split(int color, int key) const +{ + MPI_Comm splitComm = nullptr; +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Comm_split(mComm, color, key, &splitComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + return MpiComm{splitComm, true}; +} + +void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf, + std::vector const& recvcounts, std::vector const& displs, MpiType recvtype) const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(), + getMpiDtype(recvtype), mComm)); + +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +bool MpiComm::improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const +{ +#if ENABLE_MULTI_DEVICE + int flag{0}; + MPICHECK(MPI_Improbe(source, tag, mComm, &flag, msg, status)); + return flag != 0; +#else + TLLM_THROW("Multi device support is disabled."); + return false; +#endif +} + +bool MpiComm::iprobe(int source, int tag, MPI_Status* status) const +{ +#if ENABLE_MULTI_DEVICE + int flag{0}; + MPICHECK(MPI_Iprobe(source, tag, mComm, &flag, status)); + return flag != 0; +#else + TLLM_THROW("Multi device support is disabled."); + return false; +#endif +} + +void MpiComm::recvPoll(int source, int tag, int periodMs) const +{ + MPI_Status status; + while (!iprobe(source, tag, &status)) + { + 
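+        // No message matched yet; sleep for periodMs milliseconds before probing again to avoid busy-waiting.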
std::this_thread::sleep_for(std::chrono::milliseconds(periodMs)); + } +} + +int MpiComm::getRank() const +{ + int rank = 0; +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Comm_rank(mComm, &rank)); +#endif + return rank; +} + +int MpiComm::getSize() const +{ + int world_size = 1; +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Comm_size(mComm, &world_size)); +#endif + return world_size; +} + +MpiComm const& MpiComm::world() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + static MpiComm commWorld{MPI_COMM_WORLD, false}; + initialize(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return commWorld; +} + +MpiComm& MpiComm::mutableSession() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + static MpiComm commSession{MPI_COMM_WORLD, false}; + initialize(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return commSession; +} + +MpiComm& MpiComm::mutableLocalSession() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + static MpiComm localSession = initLocalSession(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return localSession; +} + +void MpiComm::refreshLocalSession() +{ +#if ENABLE_MULTI_DEVICE + static std::mutex mutex; + std::unique_lock lock(mutex); + auto initSessionRanks = getWorldRanks(MpiComm::session()); + auto localSessionRanks = getWorldRanks(MpiComm::localSession()); + + // Add to intersectionRanks in order of initSessionRanks + std::vector intersectionRanks; + std::unordered_set localSessionRanksSet(localSessionRanks.begin(), localSessionRanks.end()); + for (auto rank : initSessionRanks) + { + if (localSessionRanksSet.find(rank) != localSessionRanksSet.end()) + { + intersectionRanks.push_back(rank); + } + } + + MPI_Group worldGroup = nullptr; + MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); + MPI_Group localGroup = nullptr; + MPICHECK(MPI_Group_incl(worldGroup, intersectionRanks.size(), intersectionRanks.data(), &localGroup)); + MPI_Comm localComm = nullptr; + MPICHECK(MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm)); + MpiComm::mutableLocalSession().mFreeComm = true; + MpiComm::mutableLocalSession() = MpiComm{localComm, false}; + TLLM_LOG_INFO("Refreshed the MPI local session"); +#endif // ENABLE_MULTI_DEVICE +} + +MpiComm::MpiComm(MPI_Comm g, bool freeComm) + : mComm{g} + , mFreeComm{freeComm} +{ + TLLM_CHECK(mComm != MPI_COMM_NULL); +} + +MpiComm::~MpiComm() noexcept +{ +#if ENABLE_MULTI_DEVICE + if (mFreeComm && mComm) + { + if (MPI_Comm_free(&mComm) != MPI_SUCCESS) + { + TLLM_LOG_ERROR("MPI_Comm_free failed"); + } + } +#endif // ENABLE_MULTI_DEVICE +} + +MpiComm::MpiComm(MpiComm&& comm) noexcept + : mComm{comm.mComm} + , mFreeComm{comm.mFreeComm} +{ + comm.mFreeComm = false; +} + +MpiComm& MpiComm::operator=(MpiComm&& comm) noexcept +{ + this->~MpiComm(); + mComm = comm.mComm; + mFreeComm = comm.mFreeComm; + comm.mFreeComm = false; + return *this; +} + +MpiWaitThread::MpiWaitThread(std::string name, std::function funcWait, std::function funcSetup) + : mName{name.c_str()} + , mFuncWait{funcWait} + , mFuncSetup{funcSetup} +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + mThread = std::make_unique(&MpiWaitThread::sideThread, this); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +MpiWaitThread::~MpiWaitThread() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + waitStop(); + mShouldExit.store(true); + notifyStart(); + mThread->join(); + mThread.reset(nullptr); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} 
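
The definitions above only show the lifecycle half of MpiWaitThread, so a short usage sketch may help. It is an illustration only: it assumes notifyStart() and waitStop() are the intended public entry points and that the functor arguments are plain std::function<void()> callbacks, as the constructor defined above suggests.

#include "tensorrt_llm/common/mpiUtils.h"

void drainOneRoundExample()
{
    using tensorrt_llm::mpi::MpiWaitThread;

    MpiWaitThread waiter(
        "exampleWaiter",
        /*funcWait=*/[]() { /* e.g. block in an MPI wait or probe call for one round */ },
        /*funcSetup=*/[]() { /* optional one-time initialization on the side thread */ });

    waiter.notifyStart(); // wake the side thread; it runs funcWait once, then parks again
    waiter.waitStop();    // block until the side thread has parked
}                         // the destructor stops and joins the side thread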
+ +void MpiWaitThread::sideThread() +{ + if (mFuncSetup) + { + mFuncSetup(); + } + while (!mShouldExit.load()) + { + notifyStop(); + waitStart(); + mFuncWait(); + } +} + +void MpiWaitThread::waitStart() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::unique_lock lock(mMutex); + mCondVar.wait(lock, [this] { return mRunning; }); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +void MpiWaitThread::waitStop() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::unique_lock lock(mMutex); + mCondVar.wait(lock, [this] { return !mRunning; }); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +void MpiWaitThread::notifyStart() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::lock_guard lock(mMutex); + mRunning = true; + mCondVar.notify_one(); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +void MpiWaitThread::notifyStop() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::lock_guard lock(mMutex); + mRunning = false; + mCondVar.notify_one(); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +} // namespace tensorrt_llm::mpi diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h new file mode 100644 index 00000000000..0a9d51975af --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h @@ -0,0 +1,46 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +namespace tensorrt_llm::common::nvtx +{ +inline nvtx3::color nextColor() +{ +#ifndef NVTX_DISABLE + constexpr std::array kColors{nvtx3::color{0xff00ff00}, nvtx3::color{0xff0000ff}, nvtx3::color{0xffffff00}, + nvtx3::color{0xffff00ff}, nvtx3::color{0xff00ffff}, nvtx3::color{0xffff0000}, nvtx3::color{0xffffffff}}; + constexpr auto numColors = kColors.size(); + + static thread_local std::size_t colorId = 0; + auto const color = kColors[colorId]; + colorId = colorId + 1 >= numColors ? 0 : colorId + 1; + return color; +#else + return nvtx3::color{0}; +#endif +} + +} // namespace tensorrt_llm::common::nvtx + +#define NVTX3_SCOPED_RANGE_WITH_NAME(range, name) \ + ::nvtx3::scoped_range range(::tensorrt_llm::common::nvtx::nextColor(), name) +#define NVTX3_SCOPED_RANGE(range) NVTX3_SCOPED_RANGE_WITH_NAME(range##_range, #range) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp new file mode 100644 index 00000000000..39aefda481a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp @@ -0,0 +1,323 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tensorrt_llm/common/opUtils.h" +#include "tensorrt_llm/common/mpiUtils.h" + +#include "cuda.h" +#include +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +#define FN_NAME __FUNCTION__ +#else +#define FN_NAME __func__ +#endif + +#if ENABLE_MULTI_DEVICE + +std::unordered_map* getDtypeMap() +{ + static std::unordered_map dtypeMap = {{nvinfer1::DataType::kFLOAT, ncclFloat32}, + {nvinfer1::DataType::kHALF, ncclFloat16}, {nvinfer1::DataType::kBF16, ncclBfloat16}}; + return &dtypeMap; +} + +namespace +{ + +// Get NCCL unique ID for a group of ranks. +ncclUniqueId getUniqueId(std::set const& group) noexcept +{ + auto const rank = COMM_SESSION.getRank(); + TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank); + ncclUniqueId id; + if (rank == *group.begin()) + { + NCCLCHECK(ncclGetUniqueId(&id)); + for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it) + { + COMM_SESSION.sendValue(id, *it, 0); + } + } + else + { + COMM_SESSION.recvValue(id, *group.begin(), 0); + } + TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank); + return id; +} +} // namespace + +std::shared_ptr getComm(std::set const& group) +{ + auto const rank = COMM_SESSION.getRank(); + TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank); + static std::map, std::shared_ptr> commMap; + static std::mutex mutex; + std::lock_guard lock(mutex); + std::ostringstream oss; + int index = 0; + for (auto const& rank : group) + { + if (index != 0) + { + oss << ","; + } + oss << rank; + index++; + } + auto groupStr = oss.str(); + auto it = commMap.find(group); + if (it != commMap.end()) + { + auto ncclComm = it->second; + TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank); + return ncclComm; + } + + TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank); + ncclUniqueId id = getUniqueId(group); + int groupRank = 0; + for (auto const& currentRank : group) + { + if (rank == currentRank) + break; + ++groupRank; + } + TLLM_CHECK(groupRank < group.size()); + std::shared_ptr ncclComm(new ncclComm_t, + [](ncclComm_t* comm) + { + ncclCommDestroy(*comm); + delete comm; + }); + NCCLCHECK(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank)); + commMap[group] = ncclComm; + TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank); + return ncclComm; +} +#endif // ENABLE_MULTI_DEVICE + +void const* tensorrt_llm::common::getCommSessionHandle() +{ +#if ENABLE_MULTI_DEVICE + return &COMM_SESSION; +#else + return nullptr; +#endif // ENABLE_MULTI_DEVICE +} + +namespace +{ + +// Get current cuda context, a default context will be created if there is no context. 
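+// (cudaFree(nullptr) is used below as a conventional no-op that forces lazy initialization of the primary context
+// when no context is current yet.)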
+inline CUcontext getCurrentCudaCtx() +{ + CUcontext ctx{}; + CUresult err = cuCtxGetCurrent(&ctx); + if (err == CUDA_ERROR_NOT_INITIALIZED || ctx == nullptr) + { + TLLM_CUDA_CHECK(cudaFree(nullptr)); + err = cuCtxGetCurrent(&ctx); + } + TLLM_CHECK(err == CUDA_SUCCESS); + return ctx; +} + +// Helper to create per-cuda-context singleton managed by std::shared_ptr. +// Unlike conventional singletons, singleton created with this will be released +// when not needed, instead of on process exit. +// Objects of this class shall always be declared static / global, and shall never own CUDA +// resources. +template +class PerCudaCtxSingletonCreator +{ +public: + using CreatorFunc = std::function()>; + using DeleterFunc = std::function; + + // creator returning std::unique_ptr is by design. + // It forces separation of memory for T and memory for control blocks. + // So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released. + // creator itself must not own CUDA resources. Only the object it creates can. + PerCudaCtxSingletonCreator(CreatorFunc creator, DeleterFunc deleter) + : mCreator{std::move(creator)} + , mDeleter{std::move(deleter)} + { + } + + std::shared_ptr operator()() + { + std::lock_guard lk{mMutex}; + CUcontext ctx{getCurrentCudaCtx()}; + std::shared_ptr result = mObservers[ctx].lock(); + if (result == nullptr) + { + // Create the resource and register with an observer. + result = std::shared_ptr{mCreator().release(), + [this, ctx](T* obj) + { + if (obj == nullptr) + { + return; + } + mDeleter(obj); + + // Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts + // frequently. + std::shared_ptr observedObjHolder; // Delay destroy to avoid dead lock. + std::lock_guard lk{mMutex}; + // Must check observer again because another thread may created new instance for this ctx just + // before we lock mMutex. We can't infer that the observer is stale from the fact that obj is + // destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic + // operation, and the observer may be changed to observe another instance. + observedObjHolder = mObservers.at(ctx).lock(); + if (observedObjHolder == nullptr) + { + mObservers.erase(ctx); + } + }}; + mObservers.at(ctx) = result; + } + return result; + } + +private: + CreatorFunc mCreator; + DeleterFunc mDeleter; + mutable std::mutex mMutex; + // CUDA resources are per-context. + std::unordered_map> mObservers; +}; + +template +class PerThreadSingletonCreator +{ +public: + using CreatorFunc = std::function()>; + using DeleterFunc = std::function; + + // creator returning std::unique_ptr is by design. + // It forces separation of memory for T and memory for control blocks. + // So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released. + // creator itself must not own CUDA resources. Only the object it creates can. + PerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter) + : mCreator{std::move(creator)} + , mDeleter{std::move(deleter)} + { + } + + std::shared_ptr operator()() + { + std::lock_guard lk{mMutex}; + + std::thread::id thread = std::this_thread::get_id(); + std::shared_ptr result = mObservers[thread].lock(); + + if (result == nullptr) + { + // Create the resource and register with an observer. 
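+            // Adopt the raw pointer from the creator's unique_ptr and attach a deleter that, besides destroying
+            // the object, drops this thread's observer entry once the instance is no longer referenced.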
+ result = std::shared_ptr{mCreator().release(), + [this, thread](T* obj) + { + if (obj == nullptr) + { + return; + } + mDeleter(obj); + + // Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts + // frequently. + std::shared_ptr observedObjHolder; // Delay destroy to avoid dead lock. + std::lock_guard lk{mMutex}; + // Must check observer again because another thread may created new instance for this ctx just + // before we lock mMutex. We can't infer that the observer is stale from the fact that obj is + // destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic + // operation, and the observer may be changed to observe another instance. + observedObjHolder = mObservers.at(thread).lock(); + if (observedObjHolder == nullptr) + { + mObservers.erase(thread); + } + }}; + mObservers.at(thread) = result; + } + return result; + } + +private: + CreatorFunc mCreator; + DeleterFunc mDeleter; + mutable std::mutex mMutex; + // CUDA resources are per-thread. + std::unordered_map> mObservers; +}; + +} // namespace + +std::shared_ptr getCublasHandle() +{ + static PerThreadSingletonCreator creator( + []() -> auto + { + auto handle = std::unique_ptr(new cublasHandle_t); + TLLM_CUDA_CHECK(cublasCreate(handle.get())); + return handle; + }, + [](cublasHandle_t* handle) + { + TLLM_CUDA_CHECK(cublasDestroy(*handle)); + delete handle; + }); + return creator(); +} + +std::shared_ptr getCublasLtHandle() +{ + static PerThreadSingletonCreator creator( + []() -> auto + { + auto handle = std::unique_ptr(new cublasLtHandle_t); + TLLM_CUDA_CHECK(cublasLtCreate(handle.get())); + return handle; + }, + [](cublasLtHandle_t* handle) + { + TLLM_CUDA_CHECK(cublasLtDestroy(*handle)); + delete handle; + }); + return creator(); +} + +std::shared_ptr getCublasMMWrapper(std::shared_ptr cublasHandle, + std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace) +{ + static PerThreadSingletonCreator creator( + [cublasHandle, cublasltHandle, stream, workspace]() -> auto + { + auto wrapper = std::unique_ptr( + new tensorrt_llm::common::CublasMMWrapper(cublasHandle, cublasltHandle, stream, workspace)); + return wrapper; + }, + [](tensorrt_llm::common::CublasMMWrapper* wrapper) { delete wrapper; }); + return creator(); +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h new file mode 100644 index 00000000000..4e278e5cf23 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h @@ -0,0 +1,215 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "tensorrt_llm/common/cublasMMWrapper.h" +#include "tensorrt_llm/common/workspace.h" + +#include +#include +#include +#include +#if ENABLE_MULTI_DEVICE +#include +#endif // ENABLE_MULTI_DEVICE + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm::common +{ + +// Write values into buffer +template +void write(char*& buffer, T const& val) +{ + std::memcpy(buffer, &val, sizeof(T)); + buffer += sizeof(T); +} + +// Read values from buffer +template +void read(char const*& buffer, T& val) +{ + std::memcpy(&val, buffer, sizeof(T)); + buffer += sizeof(T); +} + +// Like std::unique_ptr, but does not prevent generation of default copy constructor when used as class members. +// The copy constructor produces nullptr. So the plugin default copy constructor will not really copy this, and +// your clone() implementation is responsible for initializing such data members. +// With this we can simplify clone() implementation when there are many data members including at least one unique_ptr. +template > +class UniqPtrWNullCopy : public std::unique_ptr +{ +public: + using std::unique_ptr::unique_ptr; + + // for compatibility with std::make_unique + explicit UniqPtrWNullCopy(std::unique_ptr&& src) + : std::unique_ptr::unique_ptr{std::move(src)} + { + } + + // copy constructor produces nullptr + UniqPtrWNullCopy(UniqPtrWNullCopy const&) + : std::unique_ptr::unique_ptr{} + { + } +}; + +// for testing only +void const* getCommSessionHandle(); +} // namespace tensorrt_llm::common + +inline bool isBuilding() +{ + auto constexpr key = "IS_BUILDING"; + auto const val = getenv(key); + return val != nullptr && std::string(val) == "1"; +} + +#if ENABLE_MULTI_DEVICE +#define NCCLCHECK(cmd) \ + do \ + { \ + ncclResult_t r = cmd; \ + if (r != ncclSuccess) \ + { \ + printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +std::unordered_map* getDtypeMap(); + +std::shared_ptr getComm(std::set const& group); + +#endif // ENABLE_MULTI_DEVICE + +//! To save GPU memory, all the plugins share the same cublas and cublasLt handle globally. +//! Get cublas and cublasLt handle for current cuda context +std::shared_ptr getCublasHandle(); +std::shared_ptr getCublasLtHandle(); +std::shared_ptr getCublasMMWrapper(std::shared_ptr cublasHandle, + std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace); + +#ifndef DEBUG + +#define PLUGIN_CHECK(status) \ + do \ + { \ + if (status != 0) \ + abort(); \ + } while (0) + +#define ASSERT_PARAM(exp) \ + do \ + { \ + if (!(exp)) \ + return STATUS_BAD_PARAM; \ + } while (0) + +#define ASSERT_FAILURE(exp) \ + do \ + { \ + if (!(exp)) \ + return STATUS_FAILURE; \ + } while (0) + +#define CSC(call, err) \ + do \ + { \ + cudaError_t cudaStatus = call; \ + if (cudaStatus != cudaSuccess) \ + { \ + return err; \ + } \ + } while (0) + +#define DEBUG_PRINTF(...) 
\ + do \ + { \ + } while (0) + +#else + +#define ASSERT_PARAM(exp) \ + do \ + { \ + if (!(exp)) \ + { \ + fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \ + return STATUS_BAD_PARAM; \ + } \ + } while (0) + +#define ASSERT_FAILURE(exp) \ + do \ + { \ + if (!(exp)) \ + { \ + fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \ + return STATUS_FAILURE; \ + } \ + } while (0) + +#define CSC(call, err) \ + do \ + { \ + cudaError_t cudaStatus = call; \ + if (cudaStatus != cudaSuccess) \ + { \ + printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \ + return err; \ + } \ + } while (0) + +#define PLUGIN_CHECK(status) \ + { \ + if (status != 0) \ + { \ + DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \ + abort(); \ + } \ + } + +#define DEBUG_PRINTF(...) \ + do \ + { \ + printf(__VA_ARGS__); \ + } while (0) + +#endif // DEBUG + +#define NVML_CHECK(cmd) \ + do \ + { \ + nvmlReturn_t r = cmd; \ + if (r != NVML_SUCCESS) \ + { \ + printf("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(r)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh new file mode 100644 index 00000000000..a228d3f9fc6 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ + +template +struct QuantTypeStaticVals; + +template <> +struct QuantTypeStaticVals +{ + static constexpr float MAX_VAL = 127.f; + static constexpr float MIN_SCALING_FACTOR = 0.f; + static constexpr float MIN_SCALING_FACTOR_RCP = FLT_MAX; +}; + +#ifdef ENABLE_FP8 + +template <> +struct QuantTypeStaticVals<__nv_fp8_e4m3> +{ + static constexpr float MAX_VAL = 448.f; + // Ref: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L720 + static constexpr float MIN_SCALING_FACTOR = 1.0f / (448.f * 512.f); + static constexpr float MIN_SCALING_FACTOR_RCP = (448.f * 512.f); +}; + +#endif // ENABLE_FP8 + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh new file mode 100644 index 00000000000..c5a4fe0e24e --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh @@ -0,0 +1,399 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) +#include +#else +#include +#endif +#include "tensorrt_llm/common/cudaTypeUtils.cuh" +#include +#include +#include +#include +#include + +namespace cg = cooperative_groups; + +namespace tensorrt_llm +{ +namespace common +{ + +template +struct BytesToType; + +template <> +struct BytesToType<1> +{ + using type = uint8_t; +}; + +template <> +struct BytesToType<2> +{ + using type = uint16_t; +}; + +template <> +struct BytesToType<4> +{ + using type = uint32_t; +}; + +template <> +struct BytesToType<8> +{ + using type = uint64_t; +}; + +template <> +struct BytesToType<16> +{ + using type = float4; +}; + +template +__device__ inline void copy(void const* local, void* data) +{ + using T = typename BytesToType::type; + + T const* in = static_cast(local); + T* out = static_cast(data); + *out = *in; +} + +static float constexpr HALF_FLT_MAX = 65504.F; +#define FINAL_MASK 0xffffffff + +template +__inline__ __device__ T warpReduceSum(T val) +{ +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80 + return val; +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockReduceSum(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T) (0.0f); + val = warpReduceSum(val); + + return val; +} + +template +__inline__ __device__ T warpReduceMax(T val) +{ +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); + return val; +} + +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockReduceMax(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + val = warpReduceMax(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; + val = warpReduceMax(val); + + return val; +} + +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockAllReduceMax(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + val = warpReduceMax(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. 
to prevent + // blockDim.x is not divided by 32 + val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; + val = warpReduceMax(val); + + return val; +} + +template +__inline__ __device__ T warpReduceSumV2(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) + { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T) (0.0f); +} + +template +__inline__ __device__ T blockReduceSumV2(T* val) +{ + static __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSumV2(val); + + if (lane == 0) + { +#pragma unroll + for (int i = 0; i < NUM; i++) + { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) + { + val[i] = is_mask ? shared[i][lane] : (T) (0.0f); + } + warpReduceSumV2(val); + return (T) 0.0f; +} + +template +__inline__ __device__ T warpReduceMaxV2(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) + { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); + } + return (T) (0.0f); +} + +template +__inline__ __device__ T blockReduceMaxV2(T* val) +{ + static __shared__ T shared[32][NUM]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + warpReduceMaxV2(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + { +#pragma unroll + for (int i = 0; i < NUM; i++) + { + shared[wid][i] = val[i]; + } + } + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) + { + val[i] = is_mask ? shared[lane][i] : (T) -1e20f; + } + warpReduceMaxV2(val); + + return (T) 0.0f; +} + +template +__inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* cgBlockReduceSumElements_shm) +{ + cg::thread_block cta = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta); + + int const tid = cta.thread_rank(); + int const blockz = blockDim.x; + for (int i = 0; i < NUM; i++) + { +#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) + cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus()); +#else + // TODO Add implementation here + if (threadIdx.x == 0 && blockIdx.x == 0) + { + printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n"); + assert(false); + } +#endif + } + cg::sync(cta); + if (tid == 0) + { +#pragma unroll + for (int i = 0; i < NUM; i++) + { + float beta = 0.0f; + for (int j = 0; j < blockz; j += 32) + { + beta += cgBlockReduceSumElements_shm[i * blockz + j]; + } + element_list[i] = beta; + } + } +} + +template +struct TopK +{ + int p[MAX_K]; // index, being -1 at the tail if the array is not full + T u[MAX_K]; // value in descend order, being -MAX_T_VAL if the element is invalid + + __device__ __forceinline__ void insert(T const elem, int const elem_id) + { + if (elem_id < 0) + { + return; + } + // Condition of updating the array + // 1. array is not full + // 2. elem is greater than the smallest (last) element in the array + // 3. 
elem is equal to the smallest (last) element in the array but its elem_id is smaller + bool const need_update + = (p[MAX_K - 1] == -1 || elem > u[MAX_K - 1] || elem == u[MAX_K - 1] && elem_id < p[MAX_K - 1]); + if (!need_update) + { + return; + } + // Find suitable index for the new element + int i; + for (i = MAX_K - 2; i >= 0; --i) + { + bool const need_decrease = (p[i] == -1 || elem > u[i] || elem == u[i] && elem_id < p[i]); + if (!need_decrease) + break; + } + // Move elements to correct positions + for (int k = MAX_K - 2; k >= i; --k) + { + p[k + 1] = p[k]; + u[k + 1] = u[k]; + } + p[i] = elem_id; + u[i] = elem; + } + + __device__ __forceinline__ void init() + { + T const MAX_T_VAL = (std::is_same::value) ? HALF_FLT_MAX : FLT_MAX; + for (int i = 0; i < MAX_K; i++) + { + p[i] = -1; + u[i] = -MAX_T_VAL; + } + } +}; + +template +__device__ __forceinline__ TopK reduce_topk_op(TopK const& a, TopK const& b) +{ + TopK res = a; + for (int i = 0; i < MAX_K; ++i) + res.insert(b.u[i], b.p[i]); + return res; +} + +template +struct TopK_2 +{ + int p = -1; + T u = -((std::is_same::value) ? HALF_FLT_MAX : FLT_MAX); + + __device__ __forceinline__ void insert(T elem, int elem_id) + { + if (elem > u) + { + u = elem; + p = elem_id; + } + } + + __device__ __forceinline__ void init() + { + u = -((std::is_same::value) ? HALF_FLT_MAX : FLT_MAX); + p = -1; + } +}; + +template +__device__ __forceinline__ TopK_2 reduce_topk_op_2(TopK_2 const& a, TopK_2 const& b) +{ + return a.u > b.u ? a : b; +} + +template +__device__ __forceinline__ T clamp_inf_for_half(float const input) +{ + return input; +} + +template <> +__device__ __forceinline__ half clamp_inf_for_half(float const input) +{ + // clamp inf values to enable fp16 training + return input > 0.0f ? (half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000); +} + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h new file mode 100644 index 00000000000..9cda9fa0d42 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace tensorrt_llm::common::stl_utils +{ + +template +constexpr TOutputIt basicInclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, TBinOp op) +{ + if (first != last) + { + auto val = *first; + while (true) + { + *dFirst = val; + ++dFirst; + ++first; + if (first == last) + { + break; + } + val = op(std::move(val), *first); + } + } + return dFirst; +} + +template +constexpr TOutputIt inclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst) +{ +#if defined(__GNUC__) && __GNUC__ <= 8 + return basicInclusiveScan(first, last, dFirst, std::plus<>{}); +#else + return std::inclusive_scan(first, last, dFirst); +#endif +} + +template +constexpr TOutputIt basicExclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init, TBinOp op) +{ + if (first != last) + { + while (true) + { + T tmp{op(init, *first)}; + *dFirst = init; + ++dFirst; + ++first; + if (first == last) + { + break; + } + init = std::move(tmp); + } + } + return dFirst; +} + +template +constexpr TOutputIt exclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init) +{ +#if defined(__GNUC__) && __GNUC__ <= 8 + return basicExclusiveScan(first, last, dFirst, std::move(init), std::plus<>{}); +#else + return std::exclusive_scan(first, last, dFirst, std::move(init)); +#endif +} + +template +struct HasOperatorOutput : std::false_type +{ +}; + +template +struct HasOperatorOutput() << std::declval()))>> + : std::true_type +{ +}; + +template +std::string toString(T const& t, typename std::enable_if_t::value, int> = 0) +{ + std::ostringstream oss; + oss << t; + return oss.str(); +} + +template +std::string toString(std::optional const& t, typename std::enable_if_t::value, int> = 0) +{ + std::ostringstream oss; + if (t) + { + oss << t.value(); + } + else + { + oss << "None"; + } + return oss.str(); +} + +} // namespace tensorrt_llm::common::stl_utils diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp new file mode 100644 index 00000000000..f1c6f88b431 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/stringUtils.h" +#include "tensorrt_llm/common/assert.h" + +#include +#include +#include +#include +#include + +namespace tensorrt_llm::common +{ + +namespace +{ +std::string vformat(char const* fmt, va_list args) +{ + va_list args0; + va_copy(args0, args); + auto const size = vsnprintf(nullptr, 0, fmt, args0); + if (size <= 0) + return ""; + + std::string stringBuf(size, char{}); + auto const size2 = std::vsnprintf(&stringBuf[0], size + 1, fmt, args); + + TLLM_CHECK_WITH_INFO(size2 == size, std::string(std::strerror(errno))); + + return stringBuf; +} + +} // namespace + +std::string fmtstr(char const* format, ...) 
+{ + va_list args; + va_start(args, format); + std::string result = vformat(format, args); + va_end(args); + return result; +}; + +std::unordered_set str2set(std::string const& input, char delimiter) +{ + std::unordered_set values; + if (!input.empty()) + { + std::stringstream valStream(input); + std::string val; + while (std::getline(valStream, val, delimiter)) + { + if (!val.empty()) + { + values.insert(val); + } + } + } + return values; +}; + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp new file mode 100644 index 00000000000..c00041abdac --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "tensorrt_llm/common/timestampUtils.h" + +namespace tensorrt_llm::common +{ + +std::string getCurrentTimestamp() +{ + auto now = std::chrono::system_clock::now(); + auto now_t = std::chrono::system_clock::to_time_t(now); + auto tm = *std::localtime(&now_t); + + auto epoch_to_now = now.time_since_epoch(); + auto seconds = std::chrono::duration_cast(epoch_to_now); + auto us = std::chrono::duration_cast(epoch_to_now - seconds); + + std::ostringstream stream; + stream << std::put_time(&tm, "%m-%d-%Y %H:%M:%S"); + stream << "." << std::setfill('0') << std::setw(6) << us.count(); + return stream.str(); +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h new file mode 100644 index 00000000000..f52f23028c1 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace tensorrt_llm::common +{ + +/// @brief Get the current timestamp in the format "MM-DD-YYYY HH:MM:SS:uuuuuu" +std::string getCurrentTimestamp(); + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp new file mode 100644 index 00000000000..b410613d055 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/tllmException.h" +#include "tensorrt_llm/common/stringUtils.h" + +#include +#if !defined(_MSC_VER) +#include +#include +#include +#endif +#include + +namespace tensorrt_llm::common +{ + +namespace +{ +int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2; +} + +#if !defined(_MSC_VER) + +TllmException::TllmException(char const* file, std::size_t line, std::string const& msg) + : std::runtime_error{""} +{ + mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES); + auto const trace = getTrace(); + std::runtime_error::operator=( + std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())}); +} +#else +TllmException::TllmException(char const* file, std::size_t line, std::string const& msg) + : mNbFrames{} + , std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)} +{ +} +#endif + +TllmException::~TllmException() noexcept = default; + +std::string TllmException::getTrace() const +{ +#if defined(_MSC_VER) + return ""; +#else + auto const trace = backtrace_symbols(mCallstack.data(), mNbFrames); + std::ostringstream buf; + for (auto i = 1; i < mNbFrames; ++i) + { + Dl_info info; + if (dladdr(mCallstack[i], &info) && info.dli_sname) + { + auto const clearName = demangle(info.dli_sname); + buf << fmtstr("%-3d %*p %s + %zd", i, VOID_PTR_SZ, mCallstack[i], clearName.c_str(), + static_cast(mCallstack[i]) - static_cast(info.dli_saddr)); + } + else + { + buf << fmtstr("%-3d %*p %s", i, VOID_PTR_SZ, mCallstack[i], trace[i]); + } + if (i < mNbFrames - 1) + buf << std::endl; + } + + if (mNbFrames == MAX_FRAMES) + buf << std::endl << "[truncated]"; + + std::free(trace); + return buf.str(); +#endif +} + +std::string TllmException::demangle(char const* name) +{ +#if defined(_MSC_VER) + return name; +#else + std::string clearName{name}; + auto status = -1; + auto const demangled = abi::__cxa_demangle(name, nullptr, nullptr, &status); + if (status == 0) + { + clearName = demangled; + std::free(demangled); + } + return clearName; +#endif +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h b/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h new file mode 100644 index 00000000000..1406e821333 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +namespace tensorrt_llm::common +{ + +std::uintptr_t constexpr kCudaMemAlign = 128; + +inline int8_t* alignPtr(int8_t* ptr, uintptr_t to) +{ + uintptr_t addr = (uintptr_t) ptr; + if (addr % to) + { + addr += to - addr % to; + } + return (int8_t*) addr; +} + +constexpr size_t alignSize(size_t size, size_t to) +{ + if ((size % to) != 0U) + { + size += to - size % to; + } + return size; +} + +inline int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment) +{ + uintptr_t addr = (uintptr_t) ptr; + addr += previousWorkspaceSize; + return alignPtr((int8_t*) addr, alignment); +} + +inline int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize) +{ + return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign); +} + +inline int8_t* nextWorkspacePtr( + int8_t* const base, uintptr_t& offset, uintptr_t const size, uintptr_t const alignment = kCudaMemAlign) +{ + uintptr_t curr_offset = offset; + uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment; + int8_t* newptr = size == 0 ? nullptr : base + curr_offset; + offset = next_offset; + return newptr; +} + +inline int8_t* nextWorkspacePtrWithAlignment( + int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment = kCudaMemAlign) +{ + return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment); +} + +inline size_t calculateTotalWorkspaceSize( + size_t const* workspaces, int count, uintptr_t const alignment = kCudaMemAlign) +{ + size_t total = 0; + for (int i = 0; i < count; i++) + { + total += workspaces[i]; + if (workspaces[i] % alignment) + { + total += alignment - (workspaces[i] % alignment); + } + } + return total; +} + +}; // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp new file mode 100644 index 00000000000..61a41031bfb --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp @@ -0,0 +1,352 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include + +// Config + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10)) +#define CUTE_ARCH_RED_F16_SM70_ENABLED +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) +#define CUTE_ARCH_RED_VEC_SM90_ENABLED +#define CUTE_ARCH_RED_BF16_SM90_ENABLED +#endif + +namespace cute +{ + +////////////////////////////////// +// Wrapper around CUDA's atomicAdd +////////////////////////////////// + +template +struct TypedAtomicAdd +{ + using SRegisters = T[1]; + using DRegisters = T[1]; + + CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst) + { + atomicAdd(&dst, src); + } +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// +// F16 ADD PTX +////////////////////////////////// + +struct SM70_RED_ADD_NOFTZ_F16 +{ + using SRegisters = uint16_t[1]; + using DRegisters = uint16_t[1]; + + CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED) + asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +struct SM70_RED_ADD_NOFTZ_F16x2 +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED) + asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx 
(one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +struct SM90_RED_ADD_NOFTZ_F16x2_V2 +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint64_t[1]; + + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED) + asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +struct SM90_RED_ADD_NOFTZ_F16x2_V4 +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void copy( + uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED) + asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1), + "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// +// BF16 ADD PTX +////////////////////////////////// + +struct SM90_RED_ADD_NOFTZ_BF16 +{ + using SRegisters = uint16_t[1]; + using DRegisters = uint16_t[1]; + + CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) + asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// + +struct SM90_RED_ADD_NOFTZ_BF16x2 +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) + asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from 
(dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// + +struct SM90_RED_ADD_NOFTZ_BF16x2_V2 +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint64_t[1]; + + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) + asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// + +struct SM90_RED_ADD_NOFTZ_BF16x2_V4 +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void copy( + uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) + asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1), + "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// + +} // end namespace cute diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h new file mode 100644 index 00000000000..2362da4f7f2 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h @@ -0,0 +1,120 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates exposing architecture support for multiply-add operations +*/ + +#pragma once +#include "cutlass_extensions/weight_only_quant_op.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace arch +{ + +// Tag which triggers MMA which will trigger +struct OpMultiplyAddDequantizeInterleavedBToA; + +/* + Below we have extra tags to signal what kind of dequantization we want to do + (per col, scale only fine grained, finegrained with zero). This still lets us + the existing template infrastructure (incl. that in CUTLASS). However, we + split out the template below into OpMultiplyAddDequantizeInterleavedBToA along + with the quantization op before instantiating the GEMM pieces. + + Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of + code we need to duplicate. + */ +struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale; +struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; + +// The default just forwards the original operator +template +struct TagOperator +{ + using TaggedOperator = MmaOp; +}; + +// Specializations below attach more information to the operator +template <> +struct TagOperator +{ + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; +}; + +template <> +struct TagOperator +{ + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale; +}; + +template <> +struct TagOperator +{ + using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; +}; + +// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original +// operator + the extra information. If no extra info was tagged, the dequant op per column scaling +// as a default. 
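The tag/detag comment above is easiest to read with a concrete round trip in mind. The sketch below is illustrative only and is not part of the upstream file; it assumes `TagOperator` is parameterized on an MMA operator plus a `WeightOnlyQuantOp` (the template parameter lists are not legible here) and that `WeightOnlyQuantOp` is visible in namespace `cutlass`, as the specializations in this header suggest.

```cpp
// Illustrative usage sketch under the assumptions stated above; requires the
// header that declares TagOperator/DetagOperator (cutlass_extensions/arch/mma.h).
#include <type_traits>

using Base = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;

// Tagging attaches the quantization mode to the operator type...
using Tagged = cutlass::arch::TagOperator<
    Base, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>::TaggedOperator;

// ...and detagging recovers both pieces before the GEMM is instantiated.
using Detagged = cutlass::arch::DetagOperator<Tagged>;
static_assert(std::is_same_v<Detagged::Operator, Base>);
static_assert(Detagged::QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY);
```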
+template +struct DetagOperator +{ + using Operator = TaggedMmaOp; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +}; + +template <> +struct DetagOperator +{ + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; +}; + +template <> +struct DetagOperator +{ + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +}; + +template <> +struct DetagOperator +{ + using Operator = OpMultiplyAddDequantizeInterleavedBToA; + static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +}; + +} // namespace arch +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h new file mode 100644 index 00000000000..c83a9a074da --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include "cutlass/device_kernel.h" +#include "tensorrt_llm/common/cudaUtils.h" + +namespace tensorrt_llm +{ +namespace cutlass_extensions +{ + +template +inline int compute_occupancy_for_kernel() +{ + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size > (48 << 10)) + { + cudaFuncAttributes attr; + int device = 0; + int max_smem_per_block = 0; + tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); + tensorrt_llm::common::check_cuda_error( + cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + if constexpr (enable_cutlass_3x) + { + tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::device_kernel)); + } + else + { + tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel)); + } + if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) + { + // This should mean that + // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) + // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this + // configuration. 
+ return 0; + } + + if constexpr (enable_cutlass_3x) + { + tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute( + cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + else + { + tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute( + cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + } + + int max_active_blocks = -1; + if constexpr (enable_cutlass_3x) + { + tensorrt_llm::common::check_cuda_error( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel, + 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size)); + } + else + { + tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, cutlass::Kernel, GemmKernel::kThreadCount, smem_size)); + } + + return max_active_blocks; +} + +} // namespace cutlass_extensions +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp new file mode 100644 index 00000000000..bba25ec23a9 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp @@ -0,0 +1,550 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/fast_math.h" + +#include "cute/numeric/numeric_types.hpp" +#include "cute/tensor.hpp" +#include "cutlass/trace.h" + +#include "cutlass_extensions/arch/copy_red_global.hpp" +#include "cutlass_extensions/util/gather_tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace epilogue +{ +namespace collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class EpilogueMoeFusedFinalize +{ +public: + using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized; + using DispatchPolicy = PtrArrayNoSmemWarpSpecialized; + + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementIntermediate = typename ThreadEpilogueOp::ElementD; + + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + + static_assert(!is_same_v, "Stride C must be a pointer"); + static_assert(is_same_v, "Stride D must not be a pointer"); + + using CopyAtomR2S = Copy_Atom; + using CopyAtomS2R = Copy_Atom; + using CopyAtomR2G = Copy_Atom; + static constexpr int AlignmentD = CopyAtomR2G::NumValSrc; + + using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{})); + + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + + struct SharedStorage + { + alignas(SmemAlignmentD) cute::ArrayEngine> smem_D; + }; + + struct TensorMapStorage + { + }; + + struct Arguments + { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C{}; + StrideC dC{}; + ElementD* ptr_D{}; + StrideD dD{}; + ElementBias const* ptr_bias; + StrideBias dBias{}; + ElementScale const* ptr_scale; + StrideScale dScale{}; + int64_t const* group_offset{}; + int32_t const* scatter_index{}; + cutlass::FastDivmod num_rows_in_final_output; + }; + + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace) + { + return args; + } + + template + static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0) + { + return 0; + } + + template + static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, + void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) + { + return cutlass::Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool can_implement( + [[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args) + { + bool implementable = true; + if (problem_shape.is_host_problem_shape_available()) + { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shape.groups(); i++) + { + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); + auto [M, N, K, L] = problem_shape_MNKL; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(M, N, L), InternalStrideD{}); + } + } + + if 
(!implementable) + { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global " + "reduction instruction.\n"); + } + return implementable; + } + + CUTLASS_HOST_DEVICE + EpilogueMoeFusedFinalize(Params const& params_) + : params(params_) + { + } + + CUTLASS_DEVICE + bool is_source_needed() + { + // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. + return params.ptr_C != nullptr + && (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0); + } + + template + CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, cute::Tensor const& accumulators, TiledMma tiled_mma, + ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + auto synchronize = [&]() + { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto mma_tile_m = tile_size<0>(tiled_mma); + auto mma_tile_n = tile_size<1>(tiled_mma); + auto epi_tile_m = size<0>(EpilogueTile{}); + auto epi_tile_n = size<1>(EpilogueTile{}); + + CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); + CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + + // Batches are managed by using appropriate pointers to C and D matrices + int32_t const mock_L = 1; + int32_t const mock_l_coord = 0; + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + + // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, + // we get the correct alpha/beta values for the current batch/group using group index. + ThreadEpilogueOp epilogue_op(params.thread, l_coord); + + SharedStorage& storage = *reinterpret_cast(smem_buf); + + Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{}); + Tensor sD = as_position_independent_swizzle_tensor(sD_); + + // Function to scatter output rows + auto& num_rows = params.num_rows_in_final_output; + auto read_scatter_map = IndexedGather(make_gmem_ptr(params.scatter_index + params.group_offset[l_coord])); + auto get_scatter_idx = [&](auto i) + { + auto scatter = read_scatter_map(i); + int quot, rem; + num_rows(quot, rem, scatter); + return rem; + }; + + // Represent the full output tensor + ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr; + auto dC = epilogue_op.is_source_needed() ? 
params.dC[l_coord] : InternalStrideC{}; + Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l) + Tensor mD_mnl = make_gather_tensor( + make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l) + + // Use fake shape for bias, it doesn't matter + bool const is_bias_needed = params.ptr_bias != nullptr; + Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias); + Tensor mScale_mnl = make_tensor( + make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale); + + Tensor gC_mnl + = local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl + = local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) + + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor gBias_mnl + = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gScale_mnl + = local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N) + Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N) + + Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Get the smallest tiled copy we can use to retile the accumulators + TiledCopy tiled_copy_C_atom + = make_tiled_copy_C_atom(Copy_Atom{}, tiled_mma); + TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom); + + auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N) + Tensor tRS_rD = make_tensor(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N) + + // Make a tiled copy vectorized along major direction of D + auto tiled_s2r = [&]() + { + if constexpr (cutlass::gemm::detail::is_k_major()) + { + constexpr int NumThreadsMajor = epi_tile_n / AlignmentD; + constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; + return make_tiled_copy(CopyAtomS2R{}, + Layout, Int>, Stride, _1>>{}, + Layout>>{}); + } + else if constexpr (cutlass::gemm::detail::is_mn_major()) + { + constexpr int NumThreadsMajor = epi_tile_m / AlignmentD; + constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; + return make_tiled_copy(CopyAtomS2R{}, + Layout, Int>, Stride<_1, Int>>{}, + Layout, _1>>{}); + } + else + { + static_assert(cute::is_void_v, "Unsupported D gmem layout."); + } + }(); + + auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); + Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // 
((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + + // Allocate intermediate registers for a single subtile + Tensor tSR_rD = make_tensor(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rD_final = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rC = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rBias = make_tensor(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rScale = make_tensor(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) + + // Make an identity coordinate tensor for predicating our output MN tile + Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + + // epilogue subtile loop + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) + { + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) + { + int mma_m = (epi_m * epi_tile_m) / mma_tile_m; + int mma_n = (epi_n * epi_tile_n) / mma_tile_n; + Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n); + + int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); + int r2s_v = epi_n_in_mma * size(tRS_rD); + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v) + { + tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v); + } + + copy(tiled_r2s, tRS_rD, tRS_sD); + synchronize(); + + copy(tiled_s2r, tSR_sD, tSR_rD); + synchronize(); + + Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n); + Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n); + Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n); + Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n); + Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n); + + if (epilogue_op.is_source_needed()) + { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_rD); ++m) + { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_rD); ++n) + { + if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) + { + copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n)); + if (is_bias_needed) + { + copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); + } + copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rD); ++i) + { + auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n)); + if (is_bias_needed) + { + epi_value += static_cast(tSR_rBias(i, m, n)); + } + tSR_rD_final(i, m, n) = static_cast(tSR_rScale(i, m, n) * epi_value); + } + copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); + } + } + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_rD); ++m) + { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_rD); ++n) + { + if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) + { + if (is_bias_needed) + { + copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); + } + copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rD); ++i) + { + auto epi_value = epilogue_op(tSR_rD(i, m, n)); + if (is_bias_needed) + { + epi_value += static_cast(tSR_rBias(i, m, n)); + } + tSR_rD_final(i, m, n) = static_cast(tSR_rScale(i, m, n) * epi_value); + } + copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); + } + } + } + } + } + } + } + +private: + Params params; +}; + +namespace detail +{ + +template +constexpr auto get_vectorized_atomic_add_op() 
+{ + using namespace cute; + + auto constexpr MaxVecSize = size(MaxVec{}); + + if constexpr (is_same_v) + { + if constexpr (MaxVecSize >= 8) + { + return SM90_RED_ADD_NOFTZ_F16x2_V4{}; + } + else if constexpr (MaxVecSize >= 4) + { + return SM90_RED_ADD_NOFTZ_F16x2_V2{}; + } + else if constexpr (MaxVecSize >= 2) + { + return SM70_RED_ADD_NOFTZ_F16x2{}; + } + else + { + return SM70_RED_ADD_NOFTZ_F16{}; + } + } + else if constexpr (is_same_v) + { + if constexpr (MaxVecSize >= 8) + { + return SM90_RED_ADD_NOFTZ_BF16x2_V4{}; + } + else if constexpr (MaxVecSize >= 4) + { + return SM90_RED_ADD_NOFTZ_BF16x2_V2{}; + } + else if constexpr (MaxVecSize >= 2) + { + return SM90_RED_ADD_NOFTZ_BF16x2{}; + } + else + { + return SM90_RED_ADD_NOFTZ_BF16{}; + } + } + else + { + // non-vectorized atomic add for all other types until supported + return TypedAtomicAdd{}; + } +} + +} // namespace detail + +template +struct EpilogueMoeFusedFinalizeBuilder +{ + + // assuming cooperative kernel schedule + using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{})); + using EpilogueTile = Shape<_128, EpiTileN>; + + // Output of linear combination is ElementCompute instead of ElementD + // since we will be doing more computate on it, no need to cast yet. + using ThreadEpilogueOp + = cutlass::epilogue::thread::LinearCombination; + + using SmemLayoutAtomD + = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()); + using CopyAtomR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator()); + using CopyAtomS2R = DefaultCopy; + using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op()); + + template + struct Sm90TmaWarpSpecializedAdapterWithSmemStorage : detail::Sm90TmaWarpSpecializedAdapter + { + // We need to override this one using declaration because otherwise we double up on the smem + using TensorMapStorage = typename EpilogueOp::TensorMapStorage; + + using Base = detail::Sm90TmaWarpSpecializedAdapter; + + CUTLASS_HOST_DEVICE + Sm90TmaWarpSpecializedAdapterWithSmemStorage( + typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors) + : Base(params) + { + } + + // These functions depend on the type of TensorMapStorage + template + CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t next_batch) + { + } + + template + CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] uint32_t lane_predicate) + { + } + }; + + using CollectiveOp = Sm90TmaWarpSpecializedAdapterWithSmemStorage< + EpilogueMoeFusedFinalize>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h new file mode 100644 index 00000000000..f3c622b88a5 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h @@ -0,0 +1,105 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination with a maximum operation used by epilogues. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/half.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace epilogue +{ +namespace thread +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +__forceinline__ __device__ float copysignf_pos(float a, float b) +{ + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__forceinline__ __device__ float tanh_opt(float x) +{ +#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) + float const exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#else + return fast_tanh(x); +#endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +template <> +struct GELU_taylor +{ + static bool const kIsHeavy = true; + + CUTLASS_DEVICE + float operator()(float const& z) const + { + + float k0 = float(0.7978845608028654); + float k1 = float(0.044715); + + return float(cutlass::constants::half() * z + * (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_DEVICE + float operator()(float const& scalar, Params const& params_) const + { + return this->operator()(scalar); + } +}; + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h new file mode 100644 index 00000000000..d3d4d0a45ab --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h @@ -0,0 +1,352 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column. + + original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_conversion.h" +#include "tensorrt_llm/common/quantization.h" + +namespace tk = tensorrt_llm::common; + +namespace cutlass +{ +namespace epilogue +{ +namespace threadblock +{ + +template +class EpilogueVisitorPerRowPerCol +{ +public: + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using ScaleTileIterator = ScaleTileIterator_; + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + using ElementOutput = typename OutputTileIterator::Element; + using LayoutOutput = cutlass::layout::RowMajor; + using ElementAccumulator = ElementAccumulator_; + + using AlphaScaleElementType = typename ScaleTileIterator::Element; + + using ElementCompute = ElementCompute_; + using AccumulatorFragment = Array; + using ComputeFragment = Array; + using OutputVector = Array; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); + + /// Argument structure + struct Arguments + { + + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + Arguments() + : batch_stride_alpha(0) + , batch_stride_C(0) + , batch_stride_D(0) + { + } + + Arguments(typename ElementwiseFunctor::Params elementwise_) + : elementwise(elementwise_) + , batch_stride_alpha(0) + , batch_stride_C(0) + , batch_stride_D(0) + { + } + + Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, + int64_t batch_stride_C_, int64_t batch_stride_D_) + : elementwise(elementwise_) + , batch_stride_alpha(batch_stride_alpha_) + , batch_stride_C(batch_stride_C_) + , batch_stride_D(batch_stride_D_) + { + } + }; + + struct Params + { + + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t 
batch_stride_D; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args) + : elementwise(args.elementwise) + , batch_stride_alpha(args.batch_stride_alpha) + , batch_stride_C(args.batch_stride_C) + , batch_stride_D(args.batch_stride_D) + { + } + }; + + /// Shared storage + struct SharedStorage + { + }; + +private: + Params const& params_; + SharedStorage& shared_storage_; + MatrixCoord extent_; + MatrixCoord extent_real_; + ElementwiseFunctor elementwise_; + + bool const per_token_quant_; + bool const per_channel_quant_; + + AlphaScaleElementType* ptr_alpha_row_; + AlphaScaleElementType* ptr_alpha_col_; + ScaleTileIterator iterator_alpha_col_; + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + + AlphaScaleElementType element_alpha_row_ = 1.0f; + AlphaScaleElementType element_alpha_col_ = 1.0f; + typename ScaleTileIterator::Fragment fragment_alpha_col_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator beta_; + + int column_offset_; + + MatrixCoord thread_offset_; + +public: + CUTLASS_DEVICE + EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, + cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row, + AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C, + typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) + : params_(params) + , shared_storage_(shared_storage) + , extent_(problem_size) + , elementwise_(params.elementwise) + , per_token_quant_(quant_option.hasPerTokenScaling()) + , per_channel_quant_(quant_option.hasPerChannelScaling()) + , ptr_alpha_row_(ptr_alpha_row) + , ptr_alpha_col_(ptr_alpha_col) + , iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset) + , iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset) + , iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset) + , extent_real_(problem_size_real) + { + beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) + { + iterator_C_.clear_mask(); + } + + if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) + { + element_alpha_col_ = *ptr_alpha_col_; + } + + if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) + { + element_alpha_row_ = *ptr_alpha_row_; + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) + { ///< Total number of split-K slices + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) + { + iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); + iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); + iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() + { + if (per_channel_quant_) + { + iterator_alpha_col_.load(fragment_alpha_col_); + } + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) + { + fragment_D_.clear(); + fragment_C_.clear(); + + if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) + { + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) + { + // load alpha_row in begin_step only when per token(row) scaling is used + if (per_token_quant_) + { + int thread_offset_row + = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); + + arch::global_load( + element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); + } + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) + { + + NumericArrayConverter source_converter; + + ComputeFragment result = source_converter(accum); + if (per_channel_quant_) + { + ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[column_idx]; + result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); + } + else + { + result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); + } + + // Convert to the output + NumericArrayConverter output_converter; + OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the end of a row + CUTLASS_DEVICE + void end_row(int row_idx) {} + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) + { + + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() {} + +private: + CUTLASS_DEVICE + ComputeFragment per_token_channel_scale_accumulator_( + ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) + { + + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) + { + result[i] = accum[i] * (scale_col[i] * scale_row); + } + + return result; + } + + CUTLASS_DEVICE + ComputeFragment per_token_scale_accumulator_( + ComputeFragment const& accum, AlphaScaleElementType const& 
scale_col, AlphaScaleElementType const& scale_row) + { + + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) + { + result[i] = accum[i] * (scale_col * scale_row); + } + + return result; + } +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h new file mode 100644 index 00000000000..6f26d790170 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -0,0 +1,282 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ + original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h + +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/platform/platform.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/epilogue/thread/linear_combination_gelu.h" +#include "cutlass/epilogue/thread/linear_combination_hardswish.h" +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_relu0.h" +#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" + +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/epilogue/thread/reduction_op.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" + +#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" +#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace epilogue +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail +{ + +/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. +template +struct DefaultIteratorsTensorOp +{ + using WarpTileIterator + = cutlass::epilogue::warp::TileIteratorTensorOpMixed; + + using SharedLoadIterator + = cutlass::epilogue::threadblock::SharedLoadIteratorMixed; + + static int const kFragmentsPerIteration = 2; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. 
+/// +/// Satisfies: ReadableTileIterator +/// +template +class SharedLoadIteratorMixed +{ +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = int32_t; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), + const_min(16, kAlignment)>; + + static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + +private: + // + // Data members + // + + /// Byte-level pointer + LoadType const* pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed(TensorRef ref, int thread_idx) + : stride_((ref.stride(0) / LoadType::kElements)) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) + { + pointers_[i] = reinterpret_cast(ref.data()); + + int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; + int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; + + col_idx += (bank_offset + i) % kLoadsPerAccess; + + pointers_[i] += thread_offset.row() * stride_ + col_idx; + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) + { + pointers_[i] += pointer_offset / LoadType::kElements; + } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& offset) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) + { + pointers_[i] + += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const + { + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) + { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) + { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) + { + + int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ + + group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_ + + pointer_offset / LoadType::kElements; + + int frag_row_idx + = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) + { + + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) + { + + int vector_idx + = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * 
kLoadsPerAccess); + + LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; + + frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; + } + } + } + } + } + } + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment& frag) const + { + + load_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h new file mode 100644 index 00000000000..233d633a823 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h @@ -0,0 +1,141 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * @file epilogue_helpers.h + * + * This file includes types for the epilogues. The empty structs exist so we can signal to template + * code the type of epilogue we want to run, and let the underlying code specify the details such as + * element types, accumulator type and elements per vector access. 
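The idiom described here, empty tag types that template code pattern-matches on while a trait supplies the concrete thread-level operator, can be illustrated with a stripped-down, self-contained sketch. The tag and operator names below are invented for the example; the real trait in this header additionally fixes the element type, elements per vector access, accumulator type, and scale mode.

#include <type_traits>

// Empty tag types: they carry no data and only select a specialization.
struct EpilogueTagDefault {};
struct EpilogueTagReLU {};

// Stand-ins for the real thread-level operators (assumptions for this sketch).
struct LinearCombinationOp {};      // D = alpha * acc + beta * C
struct LinearCombinationReluOp {};  // D = max(alpha * acc + beta * C, 0)

// Trait: maps a tag to the concrete epilogue functor. Unknown tags fail to compile
// because the primary template is left undefined.
template <typename Tag> struct EpilogueFor;
template <> struct EpilogueFor<EpilogueTagDefault> { using Op = LinearCombinationOp; };
template <> struct EpilogueFor<EpilogueTagReLU>    { using Op = LinearCombinationReluOp; };

// Kernel-level code only ever names the tag:
static_assert(std::is_same_v<EpilogueFor<EpilogueTagReLU>::Op, LinearCombinationReluOp>);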
+ * + */ + +#pragma once + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" +#include "cutlass_extensions/epilogue/thread/fused_activations.h" +#include + +namespace tensorrt_llm +{ +namespace cutlass_extensions +{ + +struct EpilogueOpBiasSilu +{ +}; + +struct EpilogueOpBiasReLU +{ +}; + +struct EpilogueOpBiasFtGelu +{ +}; + +struct EpilogueOpBias +{ +}; + +struct EpilogueOpDefaultSilu +{ +}; + +struct EpilogueOpDefaultReLU +{ +}; + +struct EpilogueOpDefaultFtGelu +{ +}; + +struct EpilogueOpDefault +{ +}; + +template +struct Epilogue +{ + static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag"); +}; + +constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +} // namespace cutlass_extensions +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl new file mode 100644 index 00000000000..593eca06e3d --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl @@ -0,0 +1,221 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail +{ + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int compute_stage_count_or_override_gated(StageCountAutoCarveout stage_count) +{ + // 32 bytes to account for barriers etc. 
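// Worked example with illustrative numbers (not taken from this patch): for a
// 128x128x64 TileShapeMNK with 16-bit A and 16-bit B, and the gate tensor streamed
// alongside B (SwapAB == false), one pipeline stage needs
//   A tile (M x K = 128 x 64):          16 * 128 * 64 / 8     = 16384 bytes
//   B tile + gate (N x K = 128 x 64):   16 * 128 * 64 * 2 / 8 = 32768 bytes
//   barrier overhead:                                            32 bytes
//   total:                                                     49184 bytes per stage
// so a hypothetical budget of 227 KB of shared memory after carveout yields
// (227 * 1024) / 49184 = 4 stages by integer division.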
+ constexpr int stage_barrier_bytes = 32; + constexpr int a_bits = static_cast(sizeof_bits::value); + constexpr int b_bits = static_cast(sizeof_bits::value); + constexpr int stage_bytes = [&]() -> int + { + if constexpr (SwapAB) + { + return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + stage_barrier_bytes; + } + else + { + return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + stage_barrier_bytes; + } + }(); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS +template class Activation, bool SwapAB> +struct CollectiveBuilderGated + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v) &¬ detail:: + is_use_rmem_A()>> +{ + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + static constexpr bool IsArrayOfPointersGemm + = (cute::is_same_v); + static constexpr bool IsFP8Input = detail::is_input_fp8(); + static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm), + "Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n"); + + // For fp32 types, map to tf32 MMA value type + using MmaElementA = cute::conditional_t, tfloat32_t, ElementA>; + using MmaElementB = cute::conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = cute::conditional_t + || IsArrayOfPointersGemm, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages + = detail::compute_stage_count_or_override_gated(StageCountType{}); + using DispatchPolicy = cute::conditional_t, + /* For FP8 use a separate mainloop compared to other datatypes */ + cute::conditional_t, + MainloopSm90TmaGmmaWarpSpecialized>>; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMmaGated, + ElementB, TagToStrideB_t, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS +template class Activation, bool SwapAB> +struct CollectiveBuilderGated + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v>> +{ + static_assert(is_static::value); + 
static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert( + detail::is_input_fp8(), "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + static constexpr bool IsArrayOfPointersGemm + = (cute::is_same_v); + using AtomLayoutMNK + = cute::conditional_t + || IsArrayOfPointersGemm, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages + = detail::compute_stage_count_or_override_gated(StageCountType{}); + using DispatchPolicy = cute::conditional_t, + MainloopSm90TmaGmmaWarpSpecialized>; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMmaGated, + ElementB, TagToStrideB_t, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp new file mode 100644 index 00000000000..2f2422c9914 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp" + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template class Activation, + bool SwapAB = false, class Enable = void> +struct CollectiveBuilderGated +{ + static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp new file mode 100644 index 00000000000..d850f36df5f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp @@ -0,0 +1,59 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template class Activation, bool SwapAB = false> +struct CollectiveMmaGated +{ + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp" +#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp new file mode 100644 index 00000000000..dcba6ee6377 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp @@ -0,0 +1,642 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/tensor_predicate.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template class Activation_, bool SwapAB_> +struct CollectiveMmaGated, TileShape_, + ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, + GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_> +{ + static constexpr bool isGated = true; + static constexpr bool SwapAB = SwapAB_; + + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using Activation = Activation_; + + using ElementAux = cute::conditional_t; + using ValTypeAux = cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly 
divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutAux = cute::conditional_t; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value + && cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + using InternalElementAux = cute::conditional_t; + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned> smem_Aux; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments + { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params + { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + using TMA_Aux = cute::conditional_t; + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Aux tma_load_aux; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + }; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, void* workspace) + { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto 
problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + if constexpr (SwapAB) + { + auto ptr_Aux = reinterpret_cast(args.ptr_A + size(make_shape(M, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; + } + else + { + auto ptr_Aux = reinterpret_cast(args.ptr_B + size(make_shape(N, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; + } + } + + template + static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) + { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); + + if (!implementable) + { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes + = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast(sizeof_bits::value)) + / 8; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + 
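// Layout note (inferred from to_underlying_arguments above; sizes are illustrative only):
// the gate operand is expected to sit contiguously after the operand it pairs with and
// to reuse that operand's strides. For SwapAB == false with fp16 B of shape
// N x K x L = 4096 x 4096 x 1, the caller provides one allocation of 2 * 4096 * 4096
// elements laid out as
//   args.ptr_B                -> "up" projection weights, shape (N, K, L), stride dB
//   args.ptr_B + 4096 * 4096  -> gate weights,            shape (N, K, L), stride dB
// which matches ptr_Aux = args.ptr_B + size(make_shape(N, K, L)) in the code above.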
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const + { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (SwapAB) + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + else + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) + { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) + { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id + = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_aux = SwapAB ? 
mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y) + : mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x); + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + Tensor tAuxgAux = block_tma_aux.partition_S(gAux); + Tensor tAuxsAux = block_tma_aux.partition_D(sAux); + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_aux = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) + { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); + } + } + + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) + { + mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); + } + } + + if constexpr (SwapAB) + { + mcast_mask_aux = mcast_mask_a; + } + else + { + mcast_mask_aux = mcast_mask_b; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); + copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter), + tAuxsAux(_, _, _, write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) + { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) + { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0, + FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors, + Params const& mainloop_params) + { + static_assert(is_rmem::value, "C tensor must be rmem 
resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutAux{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + auto tCsAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.partition_A(sAux); + } + else + { + return thread_mma.partition_B(sAux); + } + }(); + auto tCrAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.make_fragment_A(tCsAux); + } + else + { + return thread_mma.make_fragment_B(tCsAux); + } + }(); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + if constexpr (SwapAB) + { + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE + } + else + { + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE + } + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sAux)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 
1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0); + if constexpr (SwapAB) + { + cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1); + } + else + { + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0); + if constexpr (SwapAB) + { + cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1); + } + else + { + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) + { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) + { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp new file mode 100644 index 00000000000..72c1adf293f --- /dev/null +++ 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp @@ -0,0 +1,665 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/tensor_predicate.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/collective/fp8_accumulation.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template class Activation_, bool SwapAB_> +struct CollectiveMmaGated, TileShape_, + ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, + GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_> +{ + static constexpr bool isGated = true; + static constexpr bool SwapAB = SwapAB_; + + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using Activation = Activation_; + + using ElementAux = cute::conditional_t; + using ValTypeAux = cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutAux = cute::conditional_t; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value + && cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned> smem_Aux; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments + { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params + { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_, _, 0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_, _, 0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + using TMA_Aux = cute::conditional_t; + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Aux tma_load_aux; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, void* workspace) + { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b 
= make_tma_copy(GmemTiledCopyB{}, tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + if constexpr (SwapAB) + { + auto ptr_Aux = reinterpret_cast(args.ptr_A + size(make_shape(M, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval}; + } + else + { + auto ptr_Aux = reinterpret_cast(args.ptr_B + size(make_shape(N, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval}; + } + } + + template + static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) + { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); + /* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA + * instructions. */ + implementable = implementable && (args.mma_promotion_interval % 4 == 0); + + if (!implementable) + { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes + = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast(sizeof_bits::value)) + / 8; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l) + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const + { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (SwapAB) + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + else + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) + { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) + { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id + = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y) + : mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + Tensor tAuxgAux = block_tma_aux.partition_S(gAux); + Tensor tAuxsAux = block_tma_aux.partition_D(sAux); + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_aux = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) + { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); + } + } + + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) + { + mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); + } + } + + if constexpr (SwapAB) + { + mcast_mask_aux = mcast_mask_a; + } + else + { + mcast_mask_aux = mcast_mask_b; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); + copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter), + tAuxsAux(_, _, _, write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) + { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) + { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0, + FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors, + Params const& mainloop_params) + { + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + 
static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + auto tCsAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.partition_A(sAux); + } + else + { + return thread_mma.partition_B(sAux); + } + }(); + auto tCrAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.make_fragment_A(tCsAux); + } + else + { + return thread_mma.make_fragment_B(tCsAux); + } + }(); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + if constexpr (SwapAB) + { + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE + } + else + { + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE + } + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sAux)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + GmmaFP8Accumulation accumulation0(accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + GmmaFP8Accumulation accumulation1(accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + if (accumulation0.prepare_if_needed()) + { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + int read_stage = smem_pipe_read.index(); + 
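+            // Each k_block below issues two GMMAs: one accumulating A x B into accumulation0, and one
+            // pairing the auxiliary operand with B (when SwapAB) or with A (otherwise) into accumulation1.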
warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0()); + if constexpr (SwapAB) + { + cute::gemm( + tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1()); + } + else + { + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1()); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + accumulation0.promote_if_needed(); + accumulation1.promote_if_needed(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + if (accumulation0.prepare_if_needed()) + { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0()); + if constexpr (SwapAB) + { + cute::gemm( + tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1()); + } + else + { + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1()); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + + accumulation0.promote_if_needed(); + accumulation1.promote_if_needed(); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + accumulation0.promote_residue_if_needed(); + accumulation1.promote_residue_if_needed(); + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) + { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) + { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + 
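The mainloop above defers accumulator promotion to `GmmaFP8Accumulation`: GMMA partial sums are folded into the long-lived FP32 accumulators only every `mma_promotion_interval` MMA iterations (`promote_if_needed()`), with a final `promote_residue_if_needed()` after the loop. The scalar sketch below illustrates that periodic-promotion pattern only; `PROMOTION_INTERVAL`, `dot_with_promotion`, and the scalar types are illustrative stand-ins, not the kernel's actual classes.

```cpp
#include <cstdio>
#include <vector>

// Minimal scalar model of periodic accumulator promotion: partial sums are kept in a
// short-lived accumulator and folded into the main accumulator every PROMOTION_INTERVAL steps.
constexpr int PROMOTION_INTERVAL = 4;  // analogous role to mma_promotion_interval

float dot_with_promotion(const std::vector<float>& a, const std::vector<float>& b) {
    float main_acc = 0.0f;  // long-lived accumulator (FP32 accum0/accum1 in the kernel)
    float temp_acc = 0.0f;  // short-lived accumulator (the raw GMMA output in the kernel)
    for (std::size_t k = 0; k < a.size(); ++k) {
        temp_acc += a[k] * b[k];
        if ((k + 1) % PROMOTION_INTERVAL == 0) {  // promote_if_needed()
            main_acc += temp_acc;
            temp_acc = 0.0f;
        }
    }
    main_acc += temp_acc;  // promote_residue_if_needed()
    return main_acc;
}

int main() {
    std::vector<float> a(16, 0.5f), b(16, 2.0f);
    std::printf("%f\n", dot_with_promotion(a, b));  // expect 16.0
    return 0;
}
```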
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h new file mode 100644 index 00000000000..2edd5a228b4 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h @@ -0,0 +1,438 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +// #include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace device +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/* + This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) + It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs + and SmoothQuant. 
The newer device layer is not compatible with these older kernel level APIs. + + Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support + that feature at the moment. + */ + +template +class GemmUniversalBaseCompat +{ +public: + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + +protected: + /// Kernel parameters object + typename GemmKernel::Params params_; + +protected: + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) + { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + gemm_k_size = args.problem_size.k(); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + int const kAlignK + = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) + { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + } + +public: + /// Constructs the GEMM. + GemmUniversalBaseCompat() {} + + /// Determines whether the GEMM can execute the given problem. 
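+    /// The swizzled launch grid's y and z extents must each fit in 16 bits (<= 65535); larger
+    /// problems are rejected here before deferring to GemmKernel::can_implement().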
+ static Status can_implement(Arguments const& args) + { + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + ThreadblockSwizzle threadblock_swizzle; + dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + + if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) + { + + return Status::kErrorInvalidProblem; + } + + return GemmKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); + + size_t workspace_bytes = 0; + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + // Split-K parallel always requires a temporary workspace + workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); + } + else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) + { + + // Serial split-K only requires a temporary workspace if the number of partitions along the + // GEMM K dimension is greater than one. + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); + + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result << "}"); + + return result; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); + + int max_active_blocks = -1; + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + if (smem_size <= (48 << 10)) + { + + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); + + if (result == cudaSuccess) + { + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + } + else + { + + // Query assuming zero shared memory then compute occupancy limit based on SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); + + if (result != cudaSuccess) + { + + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + + return -1; + } + + if (smem_capacity < 0) + { + int device_idx = 0; + result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) + { + return -1; + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) + { + return -1; + } + + 
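+                // No shared memory capacity was supplied by the caller, so fall back to the
+                // device's reported shared memory per multiprocessor.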
smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); + } + + int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); + + CUTLASS_TRACE_HOST(" occupancy: " << occupancy); + + return occupancy; + } + + CUTLASS_TRACE_HOST(" returning internal error"); + + return -1; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + size_t workspace_bytes = get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + if (workspace_bytes) + { + + if (!workspace) + { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + + return Status::kErrorWorkspaceNull; + } + + if (args.mode == GemmUniversalMode::kGemm) + { + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + + return Status::kErrorInternal; + } + } + } + + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + // Initialize the Params structure + params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) + { + cudaError_t result + = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) + { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) + { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); + + // + // Configure grid and block dimensions + // + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + // + // Launch kernel + // + + CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); + + // Launch + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) + { + return run(stream); + } + + /// Runs the kernel using initialized state. 
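+    /// This overload first initializes state from the supplied arguments and workspace, then runs the kernel.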
+ Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) + { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) + { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h new file mode 100644 index 00000000000..bfd3666b9c1 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h @@ -0,0 +1,542 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
+ \file + \brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace device +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk, + int64_t* splitk_buffer_offsets) +{ + // in_tensor: [problem_idx, k_partition, hidden_size] + // Note that different requests of in_tensor might have different hidden_size (=m*n) + // so, we need to use splitk_buffer_offsets. + // out_tensor: problem_idx * [hidden_size] + + int const problem_idx = blockIdx.y; + GemmCoord problem = problem_sizes[problem_idx]; + int const hidden_size = problem.m() * problem.n(); + const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk; + T_OUT* out_tensor_ = out_tensor[problem_idx]; + + for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x) + { + float sum = 0.0f; + for (int k_idx = 0; k_idx < splitk; k_idx++) + { + sum += (float) in_tensor_[k_idx * hidden_size + i]; + } + out_tensor_[i] = (T_OUT) (sum); + } +} + +/// GEMM Grouped +template +class BaseSplitkGrouped +{ +public: + using BaseKernel = BaseKernel_; + + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + /// Argument structure + using Arguments = typename BaseKernel::Arguments; + + using 
ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; + +protected: + /// Kernel parameters object + typename BaseKernel::Params gemm_params_; + +private: + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) + { + int32_t tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) + { + cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; + BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); + tiles += problem_tile_count(problem); + } + return tiles; + } + + /// Copy from `data` to `workspace` + Status copy_to_workspace(void* workspace, void* data, size_t bytes) + { + cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); + if (cuda_error != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + cuda_error = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Precomputes scheduling information for the grouped GEMM + Status precompute(Arguments const& args, int32_t tile_count, void* workspace) + { + size_t workspace_bytes = get_workspace_size(args); + std::vector host_workspace(workspace_bytes); + BaseKernel::ProblemVisitor::host_precompute( + args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data()); + return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); + } + + /// Reorder `data` according to `indices` + template + static void reorder_array(T* data, std::vector const& indices) + { + // For now, simply create a copy of the data and then copy over to the original. + std::vector copy(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) + { + copy.at(i) = data[indices[i]]; + } + + memcpy(data, copy.data(), indices.size() * sizeof(T)); + } + +public: + /// Constructs the GEMM. + BaseSplitkGrouped() {} + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const& args) + { + + return BaseKernel::can_implement(args); + } + + /// Get the number of tiles in a problem + static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) + { + auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); + return BaseKernel::ProblemVisitor::tile_count(grid); + } + + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(Arguments const& args) + { + if (args.host_problem_sizes == nullptr) + { + CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); + return -1; + } + + return group_tile_count(args.host_problem_sizes, args.problem_count); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) + { + size_t total_mn = 0; + for (int i = 0; i < args.problem_count; i++) + { + total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n(); + } + size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices; + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) + { + workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size( + args.host_problem_sizes, args.problem_count, args.threadblock_count); + } + return workSpaceSize; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) + { + + return dim3(args.threadblock_count, 1, 1); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) + { + + CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()"); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) + { + result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, BaseKernel::kThreadCount, smem_size); + + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Sorts each pointer passed in according to the indices that sort + /// `problem_sizes_ptr` in descending order of problem-K dimension. 
+ static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr, + int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr, + int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr) + { + std::vector indices(problem_count); + std::iota(indices.begin(), indices.end(), 0); + std::stable_sort(indices.begin(), indices.end(), + [&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); }); + + reorder_array(problem_sizes_ptr, indices); + reorder_array(lda_host_ptr, indices); + reorder_array(ldb_host_ptr, indices); + reorder_array(ldc_host_ptr, indices); + reorder_array(ldd_host_ptr, indices); + reorder_array(offset_A_ptr, indices); + reorder_array(offset_B_ptr, indices); + reorder_array(offset_C_ptr, indices); + reorder_array(offset_D_ptr, indices); + } + + /// Computes the number of threadblocks to launch for the grouped kernel + static int sufficient( + cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1) + { + // Determine the number of blocks that would be launched to fill up a single + // wave on the GPU with each SM having maximum occupancy. + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result)); + return 0; + } + + int multiprocessor_count; + result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx); + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result)); + return 0; + } + + bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count); + if (override_sm_count) + { + available_sm_count = multiprocessor_count; + } + + int max_active_blocks = maximum_active_blocks(); + if (max_active_blocks <= 0) + { + return 0; + } + + int occupancy_based_block_count = available_sm_count * max_active_blocks; + + if (problem_sizes_ptr == nullptr || problem_count == 0) + { + return occupancy_based_block_count; + } + + int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); + + // If the group contains a single problem, launching the exact number of + // threadblocks needed to cover the problem minimizes the work performed + // per threadblock in finding the next tile to compute. We return total_tiles + // unless the user has provided the SM count. + if (problem_count == 1 && override_sm_count) + { + return total_tiles; + } + + // Choose between the full wave of threadblocks and the tile count. If there + // are fewer tiles in the group than threadblocks in the full wave, only + // some threadblocks will be assigned tiles. Those threadblocks + // which are not assigned tiles still need to perform the work of iterating through + // problem sizes to determine that they have no work to do. This competes for cycles + // with those threadblocks that are assigned tiles to compute. + return std::min(total_tiles, occupancy_based_block_count); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) + { + + CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + // Workspace + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) + { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) + { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) + { + return status; + } + + gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count); + } + else + { + gemm_params_ = typename BaseKernel::Params(args, workspace); + } + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) + { + cudaError_t result + = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) + { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) + { + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) + { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) + { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) + { + return status; + } + + gemm_params_.update(args, workspace, tile_count); + } + else + { + gemm_params_.update(args, workspace); + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) + { + if (!gemm_params_.problem_visitor.problem_count) + { + return Status::kSuccess; + } + + // + // Launch kernel + // + + // Launch splitk grouped gemm + { + dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices); + dim3 block(BaseKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + cutlass::Kernel<<>>(gemm_params_); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + // Launch splitkReduction + { + dim3 grid(32, gemm_params_.problem_visitor.problem_count); + dim3 block(256); + splitkReduction<<>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split, + gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices, + gemm_params_.splitk_buffer_offsets); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) + { + return run(stream); + } + + /// Initializes and runs the kernel. 
+ Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr) + { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) + { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM Grouped +template +class SplitkGemmGrouped : public BaseSplitkGrouped +{ +public: + using GemmKernel = GemmKernel_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h new file mode 100644 index 00000000000..100a1161a88 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -0,0 +1,162 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/half.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +template +struct MixedGemmArchTraits +{ + static_assert(dependent_false, "Unrecognised parameterization"); +}; + +template +struct MixedGemmArchTraits +{ + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::ColumnMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// ======================= Turing Traits ============================== +// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 +// and compute will happen in fp16 then will be converted for bf16 output. 
+template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ampere Traits ============================== +template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ada Traits ============================== +template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; +}; + +// FP8 A/B = fp8, C/D = fp32 +template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + // be careful, TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t + using TypeC = __nv_bfloat16; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h new file mode 
100644 index 00000000000..3fd722994e2 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h @@ -0,0 +1,57 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +template +struct Int8GemmArchTraits +{ + using OperatorClass = cutlass::arch::OpClassSimt; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; +}; + +// ======================= Turing Traits ============================== +template <> +struct Int8GemmArchTraits +{ + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; +}; + +// ======================= Ampere Traits ============================== +template <> +struct Int8GemmArchTraits +{ + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h new file mode 100644 index 00000000000..1dbd0b1765f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h @@ -0,0 +1,207 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" + +#include "cutlass/layout/permute.h" + +#include "splitk_gemm_grouped.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, + /// Operation performed by GEMM + typename Operator = typename device::DefaultGemmConfiguration::Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Permute result D + 
typename PermuteDLayout = layout::NoPermute, + /// + typename Enable = void> +struct DefaultSplitkGemmGrouped; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Real-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Permute result D + typename PermuteDLayout> +struct DefaultSplitkGemmGrouped::value>::type> +{ + + // If true, we must construct a 'transposed-and-exchanged' Mma operator. + static bool const kInternalTranspose = platform::is_same::value; + + using MapArguments = kernel::detail::MapArguments; + + // Define the default GEMM kernel + using DefaultGemmKernel = typename kernel::DefaultGemm::GemmKernel; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::SplitkGemmGrouped; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h new file mode 100644 index 00000000000..0baec58ea9a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -0,0 +1,566 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail +{ +template +inline constexpr bool dependent_false_v = false; +} + +template +struct GemmFpAIntB +{ + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Element; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; + + // Type definitions about the mainloop. 
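+    // These aliases simply re-export the tile shapes, alignments and policies already baked
+    // into Mma_ and Epilogue_, so that device-level wrappers can query the kernel (for
+    // occupancy, argument checking, etc.) without re-deriving anything themselves.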
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + + /// Parameters structure + struct Arguments + { + GemmUniversalMode mode = GemmUniversalMode::kGemm; + + cutlass::gemm::GemmCoord problem_size; + int group_size; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + + // Control serial split-k + int batch_count; + + typename EpilogueOutputOp::Params output_op; + + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // Included so we can use Gemm Universal + int batch_stride_D = 0; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Arguments() {} + + CUTLASS_HOST_DEVICE + Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, + typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, + typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), + int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) + : problem_size(problem_size) + , group_size(group_size) + , ref_A(ref_A) + , ref_B(ref_B) + , ref_scale(ref_scale) + , ref_zero(ref_zero) + , ref_C(ref_C) + , ref_D(ref_D) + , batch_count(serial_split_k_factor) + , output_op(output_op) + , gather_A_indices(gather_A_indices) + , gather_B_indices(gather_B_indices) + , scatter_D_indices(scatter_D_indices) + { + } + }; + + /// Parameters structure + struct Params + { + cutlass::gemm::GemmCoord problem_size; + int group_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::Params params_scale; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename EpilogueOutputOp::Params output_op; + int* semaphore; + int gemm_k_size; + // For gather+scatter operations + int const* 
gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0) + , semaphore(0) + , gemm_k_size(0) + { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, + void* workspace = nullptr) + : problem_size(args.problem_size) + , group_size(args.group_size) + , grid_tiled_shape(grid_tiled_shape) + , swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)) + , params_A(args.ref_A.layout()) + , ref_A(args.ref_A) + , params_B(args.ref_B.layout()) + , ref_B(args.ref_B) + , params_scale(args.ref_scale.layout()) + , ref_scale(args.ref_scale) + , ref_zero(args.ref_zero) + , params_C(args.ref_C.layout()) + , ref_C(args.ref_C) + , params_D(args.ref_D.layout()) + , ref_D(args.ref_D) + , output_op(args.output_op) + , semaphore(static_cast(workspace)) + , gemm_k_size(gemm_k_size) + , gather_A_indices(args.gather_A_indices) + , gather_B_indices(args.gather_B_indices) + , scatter_D_indices(args.scatter_D_indices) + { + } + }; + + /// Shared memory storage structure + union SharedStorage + { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntB() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(Arguments const& args) + { + static int const kAlignmentA + = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB + = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + + static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements; + + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(args.ref_A, kAlignmentA)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_B, kAlignmentB)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_zero, kAlignmentScale)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_C, kAlignmentC)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_D, kAlignmentC)) + { + return Status::kErrorMisalignedOperand; + } + + if (!args.ref_scale.good()) + { + return Status::kErrorNotSupported; + } + + if constexpr (hasZero(Mma::QuantOp)) + { + if (!args.ref_zero.good()) + { + return Status::kErrorNotSupported; + } + } + else + { + if (args.ref_zero.good()) + { + return Status::kErrorNotSupported; + } + } + + if constexpr (isFinegrained(Mma::QuantOp)) + { + if (args.group_size != 64 && args.group_size != 128) + { + return Status::kErrorNotSupported; + } + } + + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + + return 0; + } + + // Initializes the fine grained scale+bias iterator. 
Needed since the fine grained iterator + // has a different constructor signature than a regular cutlass iterator + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) + { + + return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); + } + + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) + { + + return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) + { + using LayoutB = typename Mma::IteratorB::Layout; + static_assert(platform::is_same::value && kInterleave == 1 + || platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() + || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) + { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + + typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; + typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; + cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices); + + typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B, + params.gather_B_indices); + + typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? 
problem_size_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = initialize_scale( + params.params_scale, params.ref_scale.data(), params.ref_zero.data(), + {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) + { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) + { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) + { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) + { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) + { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) + { + + // The final threadblock resets the semaphore for subsequent grids. 
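+                // (it releases 0; every other threadblock releases k + 1, which is exactly the
+                // value that the wait(k + 1) issued by the next K partition on this output tile
+                // is blocking on)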
+ lock = 0; + } + else + { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) + { + if constexpr (platform::is_same::value) + { + run_kernel_(params, shared_storage); + } + else + { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ == 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. +#else + static_assert( + false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh new file mode 100644 index 00000000000..1bd0a3f11a8 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh @@ -0,0 +1,218 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include + +namespace fused_moe +{ +template +struct Fused_Moe_Kernel_sm80 +{ + static constexpr int kMaxTileM = MaxTileM_; + static constexpr int kTileN = isGateActivation(activation_type_) ? 
TileN_ / 2 : TileN_; + static constexpr int kTileK = TileK_; + static constexpr int kStages = Stages_; + static constexpr Activation_Type activation_type = activation_type_; + + using ElementInput = ElementInput_; + using ElementWeight = ElementWeight_; + using ElementOutput = ElementOutput_; + using BaseKernelTraits = Fused_Moe_Kernel_traits_sm80; + using Routine_Arguments = Routine_Arguments; + using Routine_Params = Routine_Params; + using ProblemVisitor + = cutlass::gemm::kernel::MoeProblemVisitor, false>, + cutlass::gemm::GemmShape, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + BaseKernelTraits::kThreadCount, BaseKernelTraits::kThreadCount>; + + struct Arguments + { + Routine_Arguments routine_args; + int problem_count{}; + int threadblock_count{}; + }; + + struct Params + { + Routine_Params routine_params; + int threadblock_count{}; + typename ProblemVisitor::Params problem_visitor_param; + }; + + using BaseKernelTraits_m16 = Fused_Moe_Kernel_traits_sm80; + static constexpr bool use_m16 = TileK_ >= 64; // use tileshape m = 16 when original tileshape k >= 64 + + static constexpr int kSmemSize = use_m16 + ? (BaseKernelTraits::kSmemSize > BaseKernelTraits_m16::kSmemSize ? BaseKernelTraits::kSmemSize + : BaseKernelTraits_m16::kSmemSize) + : BaseKernelTraits::kSmemSize; + static constexpr int kThreadCount = BaseKernelTraits::kThreadCount; + + static constexpr bool can_implement(int const avaliable_smem_size) + { + return BaseKernelTraits::can_implement(avaliable_smem_size); + } + + static Params to_underlying_arguments(Arguments const& args) + { + return { + {args.routine_args.ptr_input, args.routine_args.ptr_fc1, args.routine_args.ptr_bias, + args.routine_args.ptr_output, args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, + args.routine_args.gemm_k, args.routine_args.num_expert, args.routine_args.bias_is_broadcast}, + args.threadblock_count, + {args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, args.routine_args.gemm_k, + args.problem_count, nullptr, 0}}; + } + + CUTE_DEVICE + void run_device(Params const& params) + { +#define ROUTINE_PATH(kTileM_size) \ + { \ + constexpr int kTileM = use_m16 ? (kTileM_size) : ((kTileM_size) == 16 ? 
32 : (kTileM_size)); \ + using RoutineTraits = Fused_Moe_Kernel_routine_sm80; \ + RoutineTraits routine{}; \ + int const block_m_idx = (block_m_idx_temp) *kMaxTileM / kTileM; \ + routine.run_routine(params.routine_params, problem_index, block_m_idx, block_n_idx, gemm_m); \ + } + typename ProblemVisitor::SharedStorage dummy_storage{}; + ProblemVisitor problem_visitor(params.problem_visitor_param, dummy_storage, blockIdx.x); + while (problem_visitor.next_tile()) + { + auto problem_size = problem_visitor.problem_size(); + auto grid_size = problem_visitor.grid_shape(problem_size); + auto problem_index = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + int const gemm_m = problem_size.m(); + const int32_t block_m_idx_temp = cta_idx / grid_size.n(); + const int32_t block_n_idx = cta_idx % grid_size.n(); + + int const residue_m = gemm_m - kMaxTileM * block_m_idx_temp; + if (residue_m > kMaxTileM / 2) + { + using RoutineTraits = Fused_Moe_Kernel_routine_sm80; + RoutineTraits routine{}; + routine.run_routine(params.routine_params, problem_index, block_m_idx_temp, block_n_idx, gemm_m); + } + else + { + + if constexpr (kMaxTileM >= 128) + { + if (residue_m > 32) + { + ROUTINE_PATH(64); + } + else if (residue_m > 16) + { + ROUTINE_PATH(32); + } + else + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + } + else if (kMaxTileM == 64) + { + if (residue_m > 16) + { + ROUTINE_PATH(32); + } + else + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + } + else if (kMaxTileM == 32) + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + else + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + } + problem_visitor.advance(gridDim.x); + } +#undef ROUTINE_PATH + } +}; + +template +__global__ void run_global(__grid_constant__ typename GemmType::Params const params) +{ + GemmType gemm; + gemm.run_device(params); +} + +/// Computes the maximum number of active blocks per multiprocessor +template +static int fused_gemm_maximum_active_blocks(int smem_capacity = -1) +{ + + CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()"); + + constexpr int smem_size = GemmType::kSmemSize; + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) + { + result = cudaFuncSetAttribute(run_global, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, run_global, GemmType::kThreadCount, smem_size); + + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; +} +} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh new file mode 100644 index 00000000000..4c46a541efd --- /dev/null +++ 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh @@ -0,0 +1,799 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace fused_moe +{ + +template +struct Fused_Moe_Kernel_routine_sm80; + +template +struct Fused_Moe_Kernel_routine_sm80> +{ + using KT = Fused_Moe_Kernel_traits_sm80; + using Params = Routine_Params; + + CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params) + { + using X = cute::Underscore; + + int const M = gemm_m; + int const N1 = params.gemm_n; + int const K1 = params.gemm_k; + bool const bias_is_broadcast = params.bias_is_broadcast; + + int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]); + typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1; + typename KT::ElementWeight const* ptr_fc1_gate_ + = params.ptr_fc1 + (2 * problem_index + 1) * N1 * K1; // TODO: we only focus on gated activation.. + typename KT::ElementWeight const* ptr_fc1_ + = params.ptr_fc1 + 2 * problem_index * N1 * K1; // TODO: we only focus on gated activation.. + typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr) + ? nullptr + : (bias_is_broadcast ? params.ptr_bias + 2 * problem_index * N1 : params.ptr_bias + 2 * row_jump * N1); + typename KT::ElementInput const* ptr_bias_gate_ = (params.ptr_bias == nullptr) + ? nullptr + : (bias_is_broadcast ? params.ptr_bias + (2 * problem_index + 1) * N1 + : params.ptr_bias + (2 * row_jump + 1) * N1); + typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1; + + cute::Tensor mInput_mk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_input_)), + cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mfc1_gate_nk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_gate_)), + cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mfc1_nk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_)), + cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mBias_mn = cute::make_tensor( + cute::make_gmem_ptr(static_cast(ptr_bias_)), cute::make_shape(M, N1), + cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2, + cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. + + cute::Tensor mBias_gate_mn = cute::make_tensor( + cute::make_gmem_ptr(static_cast(ptr_bias_gate_)), cute::make_shape(M, N1), + cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2, + cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. 
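+        // The stride-0 trick above lets a physically [1, N] bias be indexed as if it were
+        // [M, N]: every logical row maps to the same memory. A minimal standalone sketch of
+        // the same idea (the names and sizes here are made up purely for illustration):
+        //
+        //     float bias[4] = {1.f, 2.f, 3.f, 4.f};                // physically [1, N], N = 4
+        //     auto mB = cute::make_tensor(cute::make_gmem_ptr(bias),
+        //         cute::make_shape(8, 4),                          // advertised as [M, N], M = 8
+        //         cute::make_stride(cute::Int<0>{}, cute::_1{}));  // stride 0 along M
+        //     // mB(i, j) reads bias[j] for every i, so the bias add later in run_routine can
+        //     // use the same per-element loop it would use for a true [M, N] tensor.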
+ + cute::Tensor mOutput_mn + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_output_)), + cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{})); + + cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_K, m, k) + cute::Tensor gfc1_gate_nk = cute::local_tile(mfc1_gate_nk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) + cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) + + cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + cute::Tensor gBias_gate_mn = cute::local_tile(mBias_gate_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + return cute::make_tuple(gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn); + } + + // be careful, m_idx will change when use another tile shape.. + CUTE_DEVICE void run_routine( + Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m) + { + extern __shared__ char smem_[]; + typename KT::SharedStorage& shared_storage = *reinterpret_cast(smem_); + int const thread_idx = threadIdx.x; + bool const bias_is_broadcast = params.bias_is_broadcast; + // gmem tensor partition .. + auto [gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn] + = gmem_tensor_init(problem_index, gemm_m, params); + int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk); + auto const n_tile_count = cute::size<2>(gfc1_gate_nk); + + // smem tensor .. 
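+        // Three of these buffers are multi-stage so the cp.async pipeline below can keep
+        // KT::Stages - 1 k-tiles in flight: the activation tile (sInput), the fc1 weight tile
+        // (sfc1_weight) and the fc1 gate-weight tile (sfc1_gate_weight). sO is a single
+        // (BLK_M, BLK_N) staging buffer used only to reshuffle the accumulators before the
+        // final write back to global memory.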
+ cute::Tensor sInput = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage) + cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()), + typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) + cute::Tensor sfc1_gate_weight + = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_gate_weight.data()), + typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) + cute::Tensor sO = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N) + + // (1) first step, get the fc1_res and fc1_gate + + // (1.1) get partition for gmem -> smem + cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k) + cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) + cute::Tensor gfc1g = gfc1_gate_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) + + typename KT::GmemTiledCopyA gmem_tiled_copy_A; + typename KT::GmemTiledCopyB gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k) + cute::Tensor tInputsInput = gmem_thr_copy_A.partition_D(sInput); // (ACPY,ACPY_M,ACPY_K,Stage) + cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k) + cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage) + cute::Tensor tfc1ggfc1g = gmem_thr_copy_B.partition_S(gfc1g); // (BCPY,BCPY_N,BCPY_K,k) + cute::Tensor tfc1gsfc1g = gmem_thr_copy_B.partition_D(sfc1_gate_weight); // (BCPY,BCPY_N,BCPY_K,Stage) + + // Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor) + cute::Tensor tInputpInput + = cute::make_tensor(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)), + cute::Stride{}); + // Construct identity layout for sInput + cute::Tensor cInput = make_identity_tensor( + make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + + // Repeat the partitioning with identity layouts + cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<0>(tInputpInput); ++m) + { + tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m + } + + // (1.2) prefetch gmem -> smem + cute::clear(tInputsInput); // we don't need to clear tfc1sfc1.. 
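+        // Rows at or beyond residue_m are masked out by tInputpInput, so the cleared smem
+        // entries are what those threads' MMAs consume; the weight tiles need no predicate
+        // (and no clear), which is why tfc1sfc1 / tfc1gsfc1g are copied unconditionally below.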
+ auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // emm, iter start from 0 + int k_tile_count = cute::size<2>(gInput); + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe) + { + if (k_tile_count <= 0) + { + cute::clear(tInputpInput); + } + // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + // tInputsInput(cute::_, cute::_, cute::_, k_pipe)); + // use copy_if + cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + tInputsInput(cute::_, cute::_, cute::_, k_pipe)); + cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1sfc1(cute::_, cute::_, cute::_, k_pipe)); + cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1gsfc1g(cute::_, cute::_, cute::_, k_pipe)); + cute::cp_async_fence(); + k_tile_count--; + if (k_tile_count > 0) + { + ++k_tile_iter; + } + } + + // (1.3) get partition for rf + typename KT::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K) + cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) + cute::Tensor tOrfc1g = thr_mma.partition_fragment_B(sfc1_gate_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) + + cute::Tensor accum + = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) + cute::Tensor accum_gate + = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) + cute::clear(accum); + cute::clear(accum_gate); + // checkout the shape + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum_gate)); // MMA_M + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum_gate)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum_gate)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K + CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1g)); // MMA_K + CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma)); + CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma)); + + // (1.4)retiling the smem and rf for copy.. 
+ auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage) + cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K + + auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage) + cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K) + cute::Tensor tOsfc1g = smem_thr_copy_B.partition_S(sfc1_gate_weight); // (CPY,CPY_N,CPY_K,Stage) + cute::Tensor tOrfc1g_copy_view = smem_thr_copy_B.retile_D(tOrfc1g); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1g) == cute::size<1>(tOrfc1g_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1g) == cute::size<2>(tOrfc1g_copy_view)); // CPY_K + + // (1.5) mainloop + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = KT::Stages - 1; + + cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + cute::Tensor tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); + + constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput); + // prefetch register pipeline + if constexpr (K_BLOCK_MAX > 1) + { + cute::cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}), + tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{})); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}), + tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{})); + cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, cute::Int<0>{}), + tOrfc1g_copy_view(cute::_, cute::_, cute::Int<0>{})); + } + // k loop for mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + } + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), + tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + // cute::copy(gmem_tiled_copy_A, 
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + // tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy_if(gmem_tiled_copy_A, tInputpInput, + tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1gsfc1g(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::cp_async_fence(); + if (k_tile_count - 1 > 0) + { + ++k_tile_iter; + } + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), + accum); + cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), + tOrfc1g(cute::_, cute::_, k_block), accum_gate); + }); + } + + // load tail + cute::for_each(cute::make_int_sequence{}, + [&](auto WaitIndex) + { + k_tile_count--; + using WaitIndex_t = decltype(WaitIndex); + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + } + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), + tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); + if (k_block == 0) + { + // only update smem_pipe_read + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == KT::Stages) ? 
0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), + tOrfc1(cute::_, cute::_, k_block), accum); + cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), + tOrfc1g(cute::_, cute::_, k_block), accum_gate); + }); + }); + // mma tail + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), + tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); + // Thread-level register gemm for k_block + cute::gemm( + tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum); + cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), + tOrfc1g(cute::_, cute::_, k_block), accum_gate); + }); + // if (cute::thread0()) { + // cute::print(accum_gate(0, 0, 0)); + // printf("\n"); + // } + // (2) add bias if it has.. + if (params.ptr_bias != nullptr) + { + cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); + cute::Tensor gBias_gate = gBias_gate_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); + cute::Tensor tOgBias = thr_mma.partition_C(gBias); + cute::Tensor tOgBiasg = thr_mma.partition_C(gBias_gate); + for (int i = 0; i < cute::size(accum); i++) + { + accum(i) += tOgBias(i); + accum_gate(i) += tOgBiasg(i); + } + } + + // (3) calculate swiglu + using ActivationFn = typename KT::ActivationFn; + ActivationFn fn{}; + CUTLASS_PRAGMA_UNROLL + for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++) + { + accum(temp_iter) = fn(accum_gate(temp_iter)) * accum(temp_iter); + } + + // (4) push all the result to smem + // (4.1) convert result from ElementAccum to ElementInput + cute::Tensor temp_accum = util_convert_type(accum); + // if (cute::thread0()) { + // cute::print(temp_accum(0, 0, 0)); + // printf("\n"); + // } + // (4.2) retile rf and smem for copy back.. + auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + // cute::clear(sO); + cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum); + cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO); + + // (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..) + cute::copy(smem_tiled_copy_O, taccumrO, taccumsO); + __syncthreads(); + + // (4.4) sO -> rO -> gO + + typename KT::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + // auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); // + // remember, for all the threads in the same col, they have the same idx for bias.. + cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx); + // cute::Tensor gBias = gBias_mn(cute::_, cute::_, 0, block_n_idx); // bias only have one row.. 
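+        // The accumulators were staged into sO above; each thread now re-partitions sO with the
+        // output tiled copy, pulls its fragment back into registers (tOrO) and writes it to gO
+        // row by row, skipping rows at or beyond residue_m so that a partial last M-tile never
+        // stores out of bounds.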
+ auto tOsO = gmem_thr_copy_O.partition_S(sO); + auto tOgO = gmem_thr_copy_O.partition_D(gO); + // auto tOgBias = gmem_thr_copy_O.partition_D(gBias); + cute::Tensor cOutput = cute::make_identity_tensor( + cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{}))); + cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput); + cute::Tensor tOrO = cute::make_tensor(cute::shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<1>(tOgO); ++m) + { + if (cute::get<0>(tOcO(0, m, 0)) < residue_m) + { + cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_)); + } + } + } +}; + +template +struct Fused_Moe_Kernel_routine_sm80> +{ + + using KT = Fused_Moe_Kernel_traits_sm80; + using Params = Routine_Params; + + CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params) + { + using X = cute::Underscore; + + int const M = gemm_m; + int const N1 = params.gemm_n; + int const K1 = params.gemm_k; + bool const bias_is_broadcast = params.bias_is_broadcast; + + int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]); + typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1; + typename KT::ElementWeight const* ptr_fc1_ = params.ptr_fc1 + problem_index * N1 * K1; + typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr) + ? nullptr + : (bias_is_broadcast ? params.ptr_bias + problem_index * N1 : params.ptr_bias + row_jump * N1); + typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1; + + cute::Tensor mInput_mk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_input_)), + cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mfc1_nk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_)), + cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mBias_mn = cute::make_tensor( + cute::make_gmem_ptr(static_cast(ptr_bias_)), cute::make_shape(M, N1), + cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1, + cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. + + cute::Tensor mOutput_mn + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_output_)), + cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{})); + + cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_K, m, k) + cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) + + cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + return cute::make_tuple(gInput_mk, gfc1_nk, gBias_mn, gOutput_mn); + } + + // be careful, m_idx will change when use another tile shape.. 
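+    // (the caller, Fused_Moe_Kernel_sm80::run_device, rescales it as
+    //  block_m_idx = block_m_idx_temp * kMaxTileM / kTileM whenever it falls back to a
+    //  narrower TileM for the residue rows, so the index received here is already expressed
+    //  in units of this routine's own tile shape)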
+ CUTE_DEVICE void run_routine( + Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m) + { + extern __shared__ char smem_[]; + typename KT::SharedStorage& shared_storage = *reinterpret_cast(smem_); + int const thread_idx = threadIdx.x; + bool const bias_is_broadcast = params.bias_is_broadcast; + // gmem tensor partition .. + auto [gInput_mk, gfc1_nk, gBias_mn, gOutput_mn] = gmem_tensor_init(problem_index, gemm_m, params); + int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk); + auto const n_tile_count = cute::size<2>(gfc1_nk); + + // smem tensor .. + cute::Tensor sInput = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage) + cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()), + typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) + cute::Tensor sO = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N) + + // (1) first step, get the fc1_res and fc1_gate + + // (1.1) get partition for gmem -> smem + cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k) + cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) + + typename KT::GmemTiledCopyA gmem_tiled_copy_A; + typename KT::GmemTiledCopyB gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k) + cute::Tensor tInputsInput = gmem_thr_copy_A.partition_S(sInput); // (ACPY,ACPY_M,ACPY_K,Stage) + cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k) + cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage) + + // Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor) + cute::Tensor tInputpInput + = cute::make_tensor(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)), + cute::Stride{}); + // Construct identity layout for sInput + cute::Tensor cInput = make_identity_tensor( + make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + + // Repeat the partitioning with identity layouts + cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<0>(tInputpInput); ++m) + { + tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m + } + + // (1.2) prefetch gmem -> smem + cute::clear(tInputsInput); // we don't need to clear tfc1sfc1.. 
+ auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // emm, iter start from 0 + int k_tile_count = cute::size<2>(gInput); + CUTLASS_PRAGMA_UNROLL + for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe) + { + if (k_tile_count <= 0) + { + cute::clear(tInputpInput); + } + // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + // tInputsInput(cute::_, cute::_, cute::_, k_pipe)); + // use copy_if + cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + tInputsInput(cute::_, cute::_, cute::_, k_pipe)); + cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1sfc1(cute::_, cute::_, cute::_, k_pipe)); + cute::cp_async_fence(); + k_tile_count--; + if (k_tile_count > 0) + { + ++k_tile_iter; + } + } + + // (1.3) get partition for rf + typename KT::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K) + cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) + + cute::Tensor accum + = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) + cute::clear(accum); + // checkout the shape + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K + CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma)); + CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma)); + + // (1.4)retiling the smem and rf for copy.. 
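+ // Retiling here means re-viewing the same data through the copy atoms: make_tiled_copy_A/B derive
+ // a smem-to-register tiled copy (typically ldmatrix-based, see SmemCopyAtomA/B in the traits)
+ // whose per-thread value layout matches what the tensor-core MMA fragments expect, and retile_D
+ // re-views tOrInput / tOrfc1 with that grouping. The CUTE_STATIC_ASSERTs that follow only check
+ // that the M/N and K modes of the smem view and of the register view line up.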
+ auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage) + cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K + + auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage) + cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K + + // (1.5) mainloop + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = KT::Stages - 1; + + cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + + constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput); + // prefetch register pipeline + if constexpr (K_BLOCK_MAX > 1) + { + cute::cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}), + tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{})); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}), + tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{})); + } + // k loop for mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + } + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + // tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy_if(gmem_tiled_copy_A, tInputpInput, + tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::cp_async_fence(); + if (k_tile_count - 1 > 0) + { + ++k_tile_iter; + } + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == KT::Stages) ? 
0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), + accum); + }); + } + // load tail + cute::for_each(cute::make_int_sequence{}, + [&](auto WaitIndex) + { + k_tile_count--; + using WaitIndex_t = decltype(WaitIndex); + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + } + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + if (k_block == 0) + { + // only update smem_pipe_read + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), + tOrfc1(cute::_, cute::_, k_block), accum); + }); + }); + // mma tail + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + // Thread-level register gemm for k_block + cute::gemm( + tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum); + }); + // if (cute::thread0()) { + // cute::print(accum_gate(0, 0, 0)); + // printf("\n"); + // } + // (2) add bias if it has.. + if (params.ptr_bias != nullptr) + { + cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); + cute::Tensor tOgBias = thr_mma.partition_C(gBias); + for (int i = 0; i < cute::size(accum); i++) + { + accum(i) += tOgBias(i); + } + } + // (3) calculate swiglu + using ActivationFn = typename KT::ActivationFn; + ActivationFn fn{}; + CUTLASS_PRAGMA_UNROLL + for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++) + { + accum(temp_iter) = fn(accum(temp_iter)); + } + + // (4) push all the result to smem + // (4.1) convert result from ElementAccum to ElementInput + cute::Tensor temp_accum = util_convert_type(accum); + // if (cute::thread0()) { + // cute::print(temp_accum(0, 0, 0)); + // printf("\n"); + // } + // (4.2) retile rf and smem for copy back.. + auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + // cute::clear(sO); + cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum); + cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO); + + // (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..) 
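+ // Steps (4.3)/(4.4) stage the activated accumulator through shared memory so the global store can
+ // be re-partitioned into wide, coalesced vectors via GmemTiledCopyO instead of the MMA's scattered
+ // C-fragment layout. The final smem -> register -> gmem copy is predicated per row against
+ // residue_m (see the identity tensor cOutput / tOcO below), mirroring the input-side predication,
+ // so a partial M tile at the end of an expert's token range never stores out of bounds.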
+ cute::copy(smem_tiled_copy_O, taccumrO, taccumsO); + __syncthreads(); + + // (4.4) sO -> rO -> gO + + typename KT::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + // auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); // + cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx); + auto tOsO = gmem_thr_copy_O.partition_S(sO); + auto tOgO = gmem_thr_copy_O.partition_D(gO); + cute::Tensor cOutput = cute::make_identity_tensor( + cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{}))); + cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput); + cute::Tensor tOrO = cute::make_tensor(cute::shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<1>(tOgO); ++m) + { + if (cute::get<0>(tOcO(0, m, 0)) < residue_m) + { + cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_)); + } + } + } +}; + +} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh new file mode 100644 index 00000000000..b4c90085dbb --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh @@ -0,0 +1,215 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace fused_moe +{ +template +struct Routine_Arguments +{ + ElementInput* ptr_input{}; + ElementWeight* ptr_fc1{}; + ElementInput* ptr_bias{}; + ElementOutput* ptr_output{}; + int64_t const* total_tokens_including_expert{}; + int gemm_n{}; + int gemm_k{}; + int num_expert{}; + bool bias_is_broadcast{}; +}; + +template +struct Routine_Params +{ + ElementInput* ptr_input{}; + ElementWeight* ptr_fc1{}; + ElementInput* ptr_bias{}; + ElementOutput* ptr_output{}; + int64_t const* total_tokens_including_expert{}; + int gemm_n{}; + int gemm_k{}; + int num_expert{}; + bool bias_is_broadcast{}; +}; + +enum class Activation_Type +{ + Gelu = 0, + Relu, + Silu, + Swiglu, + Geglu, + Identity, + InvalidType +}; + +constexpr bool isGateActivation(Activation_Type const& activation_type) +{ + return activation_type == Activation_Type::Swiglu || activation_type == Activation_Type::Geglu; +} + +template +constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) +{ + return Activation_Type::InvalidType; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) +{ + return Activation_Type::Identity; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) +{ + return Activation_Type::Relu; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool is_gate) +{ + return is_gate ? Activation_Type::Swiglu : Activation_Type::Silu; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool is_gate) +{ + return is_gate ? Activation_Type::Geglu : Activation_Type::Gelu; +} + +/* fusing all three kernels has many limitations. This is the simpler version. Just fuse first two kernels..*/ +template +struct Fused_Moe_Kernel_traits_sm80 +{ + using ElementInput = ElementInput_; + using ElementWeight = ElementWeight_; + using ElementAccum = float; + using ElementOutput = ElementOutput_; + + using index_t = uint32_t; + static_assert(TileM_ % 16 == 0); + static_assert(TileN_ % 32 == 0); + static_assert(TileK_ % 32 == 0); + static constexpr int Stages = Stages_; + static constexpr int kTileM = TileM_; + static constexpr int kTileN = TileN_; + static constexpr int kTileK = (kTileM > 16) ? (TileK_) : (TileK_ >= 64 ? TileK_ : 64); + + // tile shape + using TileShape = cute::Shape, cute::Int, cute::Int>; + static constexpr int kWarpsCount = 4; + static constexpr int kThreadCount = kWarpsCount * 32; + + // MMA atom arch and layout + using MMA_Atom_Arch = std::conditional_t, + cute::MMA_Atom, cute::MMA_Atom>; + // using ValLayoutMNK = cute::Layout>; + using ThreadLayoutMNK + = std::conditional_t, cute::_1>>, + cute::Layout, cute::_1>>>; + using ValLayoutMNK = std::conditional_t, + cute::Tile>; + using TiledMma = cute::TiledMMA; // 32x32x16 or 16x64x16 MMA for LDSM if kWarp = 4 + static constexpr int kAlignment = 8; + static constexpr int kBlcokKSmem = (kTileM == 16) ? 
64 : 32; + // A memory copy operand + using DefaultOperandA + = DefaultGemm_TensorOpSm80_OperandA; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B memory copy operand + using DefaultOperandB + = DefaultGemm_TensorOpSm80_OperandB; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Output memory copy operand + using SmemLayoutAtomO = SmemLayoutAtomA; + using SmemCopyAtomO = cute::Copy_Atom; + static constexpr int kGmemElementPerLoad = sizeof(cute::uint128_t) / sizeof(ElementOutput); + static constexpr int kGmemTrheadsPerRow = kBlcokKSmem / kGmemElementPerLoad; + using GmemLayoutAtomO + = cute::Layout, cute::Int>, + cute::Stride, cute::_1>>; + using GmemTiledCopyO = decltype(cute::make_tiled_copy(cute::Copy_Atom{}, + GmemLayoutAtomO{}, cute::Layout>{})); + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2); + static_assert(cute::size<0>(TileShape{}) % cute::size<0>(SmemLayoutAtomA{}) == 0); // M + static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomA{}) == 0); // K + static_assert(cute::rank(SmemLayoutAtomB{}) == 2); + static_assert(cute::size<1>(TileShape{}) % cute::size<0>(SmemLayoutAtomB{}) == 0); // N + static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomB{}) == 0); // K + + using SmemLayoutA = decltype(cute::tile_to_shape(SmemLayoutAtomA{}, + cute::make_shape( + cute::shape<0>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int{}))); // BLK_M, BLK_K, Stages + using SmemLayoutB = decltype(cute::tile_to_shape(SmemLayoutAtomB{}, + cute::make_shape( + cute::shape<1>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int{}))); // BLK_N, BLK_K, Stages + using SmemLayoutO = decltype(cute::tile_to_shape( + SmemLayoutAtomO{}, cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<1>(TileShape{})))); // BLK_M, BLK_N + + // we need at least 2 stages.. + static_assert(Stages >= 2); + + struct SharedStorageNormal : cute::aligned_struct<128> + { + cute::array_aligned> smem_input; + cute::array_aligned> smem_fc1_weight; + cute::array_aligned> smem_o; + }; + + struct SharedStorageGate : cute::aligned_struct<128> + { + cute::array_aligned> smem_input; + cute::array_aligned> smem_fc1_gate_weight; + cute::array_aligned> smem_fc1_weight; + cute::array_aligned> smem_o; + }; + + using SharedStorage = std::conditional_t; + + using ActivationFn = std::conditional_t, + std::conditional_t, + std::conditional_t, cutlass::epilogue::thread::Identity>>>; + + static constexpr int kSmemSize = static_cast(sizeof(SharedStorage)); + + static constexpr bool can_implement(int const avaliable_smem_size) + { + return avaliable_smem_size > kSmemSize; + } + + // #endif +}; +} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h new file mode 100644 index 00000000000..80a4d856085 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h @@ -0,0 +1,73 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Scheduler for grouped GEMM +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/matrix_coord.h" + +#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" +#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct GemmMoeProblemVisitor + : public MoeProblemVisitor, ThreadblockShape, + GroupScheduleMode_, PrefetchTileCount, ThreadCount> +{ + + static bool const kTransposed = Transposed; + + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using Base + = MoeProblemVisitor; + using Params = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + // + // Methods + // + CUTLASS_DEVICE + GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, shared_storage_, block_idx) + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp new file mode 100644 index 00000000000..3a084ee04fb --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp @@ -0,0 +1,70 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel +{ + +//////////////////////////////////////////////////////////////////////////////// + +/* + * Stateless universal device GEMM kernel type that treats GEMM as + * a composition of a collective mainloop and a collective epilogue. + * + * Supports both the 2.x and 3.x APIs based on whether the first type is + * a cute::tuple<> or not. + * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h + * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp + * + * In the following declaration, the name preceding the 'Or' refers to + * 3.x API type argument order, and the name succeeding the 'Or' refers to + * 2.x API type argument order. Template arguments without two names + * belong to the 3.x API only. + **/ +template +class GemmUniversalGated; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel + +//////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp" +#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp" +//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h new file mode 100644 index 00000000000..0650ca8ded4 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h @@ -0,0 +1,585 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GEMM kernel to support the epilogue visitor model + for customized softmax partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. + + original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/trace.h" + +#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h" + +namespace tk = tensorrt_llm::common; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithEpilogueVisitor +{ +public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementCompute = typename EpilogueVisitor::ElementCompute; + using LayoutAlphaCol = cutlass::layout::RowMajor; + using LayoutAlphaRow = cutlass::layout::ColumnMajor; + using TensorRefAlphaCol = TensorRef; + using TensorRefAlphaRow = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename 
Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + using EpilogueOutputOp = + typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment + = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments + { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + TensorRefA ref_A; + TensorRefB ref_B; + tk::QuantMode quant_option; + TensorRefAlphaCol ref_alpha_col; + TensorRefAlphaRow ref_alpha_row; + TensorRefC ref_C; + TensorRefC ref_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_D; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments() + : mode(GemmUniversalMode::kGemm) + , batch_count(1) + { + } + + /// constructs an arguments structure + Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_, + TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_, + TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_, + int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_) + : mode(mode_) + , problem_size(problem_size_) + , batch_count(batch_count_) + , ref_A(ref_A_) + , ref_B(ref_B_) + , quant_option(quant_option_) + , ref_alpha_col(ref_alpha_col_) + , ref_alpha_row(ref_alpha_row_) + , ref_C(ref_C_) + , ref_D(ref_D_) + , batch_stride_A(batch_stride_A_) + , batch_stride_B(batch_stride_B_) + , batch_stride_D(0) + , epilogue_visitor(epilogue_visitor_) + { + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params + { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void* ptr_A; + void* ptr_B; + tk::QuantMode quant_option; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0) + , params_A(0) + , params_B(0) + , params_alpha_col(0) + , params_C(0) + , params_D(0) + , batch_count(0) + , gemm_k_size(0) + , 
mode(cutlass::gemm::GemmUniversalMode::kGemm) + , ptr_A(nullptr) + , ptr_B(nullptr) + , ptr_alpha_col(nullptr) + , ptr_alpha_row(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , batch_stride_A(0) + , batch_stride_B(0) + { + } + + Params( + Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_) + : problem_size(args.problem_size) + , swizzle_log_tile(0) + , params_A(args.ref_A.layout()) + , params_B(args.ref_B.layout()) + , params_alpha_col(args.ref_alpha_col.layout()) + , params_alpha_row(args.ref_alpha_col.layout()) + , params_C(args.ref_C.layout()) + , params_D(args.ref_D.layout()) + , mode(args.mode) + , batch_count(args.batch_count) + , gemm_k_size(args.problem_size.k()) + , ptr_A(args.ref_A.data()) + , ptr_B(args.ref_B.data()) + , quant_option(args.quant_option) + , ptr_alpha_col(args.ref_alpha_col.data()) + , ptr_alpha_row(args.ref_alpha_row.data()) + , ptr_C(args.ref_C.data()) + , ptr_D(args.ref_D.data()) + , batch_stride_A(args.batch_stride_A) + , batch_stride_B(args.batch_stride_B) + , epilogue_visitor(args.epilogue_visitor) + { + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + int const kAlignK + = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) + { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage + { + + typename Mma::SharedStorage main_loop; + + struct + { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + +public: + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { + + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) + { + isAMisaligned = problem_size.k() % kAlignmentA; + } + else if (platform::is_same::value) + { + isAMisaligned = problem_size.m() % kAlignmentA; + } + else if (platform::is_same>::value + || platform::is_same>::value) + { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) + { + isBMisaligned = problem_size.n() % kAlignmentB; + } + else if (platform::is_same::value) + { + isBMisaligned = problem_size.k() % kAlignmentB; + } + else if (platform::is_same>::value + || platform::is_same>::value) + { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) + { + isCMisaligned = problem_size.n() % kAlignmentC; + } + else if (platform::is_same::value) + { + isCMisaligned = problem_size.m() % kAlignmentC; + } + else if (platform::is_same>::value + || platform::is_same>::value) + { + 
isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) + { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) + { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) + { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) + { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + + return 0; + } + +#define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) + { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() + || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) + { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + +#if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) + { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) + { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) + { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } +#endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
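+ // threadIdx.x / 32 already evaluates to the same value in every lane, but the compiler cannot
+ // always prove that; re-reading it through __shfl_sync from lane 0 lets warp_idx be treated as
+ // warp-uniform, which helps code generation for everything downstream that indexes by warp. This
+ // is a common CUTLASS idiom rather than anything specific to this epilogue-visitor kernel.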
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, + params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C, + params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C, + params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m()); + + if (params.mode == GemmUniversalMode::kGemm) + { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) + { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) + { + if constexpr (platform::is_same::value) + { + run_kernel_(params, shared_storage); + } + else + { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels. + run_kernel(params, shared_storage); +#else + static_assert( + false, "Invalid architecture being compiled. 
Only Volta+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h new file mode 100644 index 00000000000..6dc6ffc1a9f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -0,0 +1,143 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is + quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices + to be consumed by CUTLASS. + + Note that for int4, ThreadBlockK MUST be 64. + + */ + +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/platform/platform.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +template +struct LayoutDetailsB +{ +}; + +// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. +// TODO - Switch this to column major for weights since gemms should be more performant. 
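+ // In the specializations below, ThreadblockK is derived as 128 * 8 / sizeof_bits of the element
+ // type, i.e. how many elements fill one 128-byte cache line (64 for a 16-bit type, 128 for int8,
+ // 256 for int4), and ElementsPerAccess = 128 / sizeof_bits is one 128-bit vector load. For the
+ // interleaved quantized layouts, ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK folds
+ // several logical columns together so that each K-slice reads whole cache lines, consistent with
+ // the file-header note that ThreadblockK must be 64 for int4.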
+template +struct LayoutDetailsB= 75>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB= 75>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB +{ + static constexpr int ThreadblockK = 64; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; + // for fast accumulation + // using Operator = cutlass::arch::OpMultiplyAddFastAccum; +}; + +// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, +// which signals that we want to dequantize after loading from smem. +template + struct LayoutDetailsB < TypeA, + uint8_t, Arch, + typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template + struct LayoutDetailsB < TypeA, + uint4b_t, Arch, + typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template +struct LayoutDetailsB= 90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB= 90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh new file mode 100644 index 00000000000..aac2cb35799 --- /dev/null +++ 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh @@ -0,0 +1,185 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include + +template +struct DefaultGemm_TensorOpSm80_OperandA; + +template +struct DefaultGemm_TensorOpSm80_OperandB; + +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::half_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::bfloat16_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +/// Operand A - Column-major (M-major) +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::half_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::bfloat16_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands + +// Operand B - Column-Major (K-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +// Operand B - Row-Major (N-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +// +// F16: 128-by-128-by-32 (small k-block) +// + +/// Operand A - Row-major (K-Major) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<2, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + 
cute::Copy_Atom, cute::half_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<2, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::bfloat16_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template +CUTE_DEVICE auto util_convert_type(cute::Tensor const& tensor) +{ + using From_type = typename Engine::value_type; + constexpr int numel = decltype(cute::size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast const*>(tensor.data())); + return cute::make_tensor(cute::make_rmem_ptr(&frag), tensor.layout()); +} + +template +CUTE_DEVICE void util_copy( + TiledCopy const& tiled_copy, cute::Tensor const& S, cute::Tensor& D) +{ + CUTE_STATIC_ASSERT_V(cute::rank(S) == cute::Int<3>{}); + CUTE_STATIC_ASSERT_V(cute::rank(D) == cute::Int<3>{}); + CUTE_STATIC_ASSERT_V(cute::size<0>(S) == cute::size<0>(D)); + CUTE_STATIC_ASSERT_V(cute::size<1>(S) == cute::size<1>(D)); + CUTE_STATIC_ASSERT_V(cute::size<2>(S) == cute::size<2>(D)); + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<1>(S); ++m) + { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < cute::size<2>(S); ++k) + { + cute::copy(tiled_copy, S(cute::_, m, k), D(cute::_, m, k)); + } + } +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h new file mode 100644 index 00000000000..b708f7c28b5 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h @@ -0,0 +1,553 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +// This section exists to that we can use the same kernel code for regular gemm and dequantizing gemms. +// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global. +template +using void_t = void; + +template +struct use_dq_gemm : platform::false_type +{ +}; + +template +struct use_dq_gemm> : platform::true_type +{ +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeFCGemm +{ +public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = false; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + static_assert(!kTransposed, "Transpose problem not supported"); + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. 
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor + = GemmMoeProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments + { + + // + // Data members + // + + int problem_count; + int threadblock_count; + int group_size; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + bool C_is_broadcast; + + int64_t const* total_tokens_including_expert; + int64_t gemm_n; + int64_t gemm_k; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0) + , threadblock_count(0) + , ptr_A(nullptr) + , ptr_B(nullptr) + , weight_scales(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , total_tokens_including_expert(nullptr) + , gemm_n(0) + , gemm_k(0) + , host_problem_sizes(nullptr) + , C_is_broadcast{true} + { + } + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op, + ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C, + bool C_is_broadcast, ElementC* ptr_D, int64_t const* total_tokens_including_expert, int64_t gemm_n, + int64_t gemm_k, GemmCoord* host_problem_sizes = nullptr) + : problem_count(problem_count) + , threadblock_count(threadblock_count) + , group_size(group_size) + , output_op(output_op) + , ptr_A(const_cast(ptr_A)) + , ptr_B(const_cast(ptr_B)) + , weight_scales(const_cast(weight_scales)) + , ptr_C(const_cast(ptr_C)) + , C_is_broadcast{C_is_broadcast} + , ptr_D(ptr_D) + , total_tokens_including_expert(total_tokens_including_expert) + , gemm_n(gemm_n) + , gemm_k(gemm_k) + , host_problem_sizes(nullptr) + { + if (platform::is_same::value || platform::is_same::value) + { + assert(weight_scales); + } + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params + { + + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + int group_size; + bool C_is_broadcast; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : ptr_A(nullptr) + , ptr_B(nullptr) + , weight_scales(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , C_is_broadcast(true) + { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor( + args.total_tokens_including_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count) + , threadblock_count(args.threadblock_count) + , group_size(args.group_size) + , 
output_op(args.output_op) + , ptr_A(args.ptr_A) + , ptr_B(args.ptr_B) + , weight_scales(args.weight_scales) + , ptr_C(args.ptr_C) + , ptr_D(args.ptr_D) + , C_is_broadcast(args.C_is_broadcast) + { + } + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + { + + problem_visitor = typename ProblemVisitor::Params(args.total_tokens_including_expert, args.gemm_n, + args.gemm_k, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + weight_scales = args.weight_scales; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + C_is_broadcast = args.C_is_broadcast; + } + }; + + /// Shared memory storage structure + union SharedStorage + { + typename ProblemVisitor::SharedStorage problem_visitor; + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + // + // Methods + // + + CUTLASS_DEVICE + MoeFCGemm() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) + { + if (platform::is_same::value || platform::is_same::value) + { + if (args.weight_scales == nullptr) + { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t"); + return Status::kInvalid; + } + } + else if (args.weight_scales != nullptr) + { + CUTLASS_TRACE_HOST( + "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); + return Status::kInvalid; + } + else if (args.group_size != args.gemm_k) + { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)"); + return Status::kInvalid; + } + // Handle the case the input is too short + else if (args.gemm_n < Mma::IteratorB::AccessType::kElements) + { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment"); + return Status::kInvalid; + } + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + + return 0; + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) + { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static_assert(platform::is_same::value && kInterleave == 1 + || platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // + // Problem visitor. 
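// --- Editorial sketch (not part of the original patch) ----------------------
// total_tokens_including_expert is handed to the problem visitor as the
// inclusive prefix sum of rows routed to each expert: entry e is the last row
// (exclusive end) owned by expert e, so expert e's GEMM has
// M_e = last_row[e] - last_row[e-1] rows, sharing N = gemm_n and K = gemm_k.
// A host-side sketch of how such an array could be built from per-expert
// token counts; the function and variable names here are illustrative only.
#include <cstdint>
#include <vector>

std::vector<int64_t> rows_including_expert(const std::vector<int64_t>& tokens_per_expert)
{
    std::vector<int64_t> last_row(tokens_per_expert.size());
    int64_t running = 0;
    for (size_t e = 0; e < tokens_per_expert.size(); ++e)
    {
        running += tokens_per_expert[e];   // inclusive prefix sum
        last_row[e] = running;             // last row (exclusive end) for expert e
    }
    return last_row;                        // e.g. {3, 5, 9} for counts {3, 2, 4}
}
// ----------------------------------------------------------------------------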
+ // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + + // Outer 'persistent' loop to iterate over tiles + int loop = 0; + while (problem_visitor.next_tile()) + { + loop++; + + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset( + int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); + + // Load element pointers. Exchange pointers and strides if working on the transpose + const int64_t rows_to_jump + = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; + + char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix; + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B + = platform::is_same::value ? gemm_n : gemm_k * kInterleave; + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; + + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, + {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + auto CreateMMA = [&]() + { + if constexpr (use_dq_gemm::value) + return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + else + return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + }; + Mma mma = CreateMMA(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. 
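// --- Editorial sketch (not part of the original patch) ----------------------
// Inside the persistent loop each threadblock turns its linear tile index
// within the current expert's grid into a 2-D tile coordinate
// (cta_idx / grid_n, cta_idx % grid_n), and offsets ptr_A past the rows owned
// by earlier experts (rows_to_jump * gemm_k elements). A host-side rehearsal
// of that arithmetic; the tile sizes and indices below are made up.
#include <cstdint>
#include <cstdio>

int main()
{
    const int kTileM = 128, kTileN = 128;
    const int64_t gemm_k = 4096;

    const int grid_n = 6;          // tiles along N for this expert's problem
    const int cta_idx = 13;        // threadblock index within this problem

    const int block_m = cta_idx / grid_n;    // row of tiles
    const int block_n = cta_idx % grid_n;    // column of tiles
    const int offset_m = block_m * kTileM;
    const int offset_n = block_n * kTileN;

    // Rows handled by earlier experts; A for this expert starts after them.
    const int64_t rows_to_jump = 2048;
    const int64_t a_elem_offset = rows_to_jump * gemm_k;

    std::printf("tile (%d,%d) -> offset (%d,%d), A starts at element %lld\n",
                block_m, block_n, offset_m, offset_n, (long long) a_elem_offset);
    return 0;
}
// ----------------------------------------------------------------------------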
+ __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); + + if constexpr (use_dq_gemm::value) + { + const MatrixCoord scale_extent = {1, problem_size.n()}; + typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), + weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale); + + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } + else + { + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + } + + // + // Epilogue + // + + ElementC* ptr_C = reinterpret_cast(params.ptr_C) + + (params.C_is_broadcast ? problem_idx : rows_to_jump) * gemm_n; + ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + + // lora need to set as layout_C(gemm_n) + LayoutC layout_C = params.C_is_broadcast ? LayoutC(0) : LayoutC(gemm_n); + LayoutC layout_D(gemm_n); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn()); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn()); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + if constexpr (platform::is_same>::value) + { + EpilogueOutputOp output_op(params.output_op, problem_idx); + epilogue(output_op, iterator_D, accumulators, iterator_C); + } + else + { + EpilogueOutputOp output_op(params.output_op); + epilogue(output_op, iterator_D, accumulators, iterator_C); + } + + // Next tile + problem_visitor.advance(gridDim.x); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) + { + if constexpr (platform::is_same::value) + { + run_kernel_(params, shared_storage); + } + else + { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900) + constexpr bool isFp8 = platform::is_same::value + || platform::is_same::value; + if constexpr (isFp8) + { + run_kernel(params, shared_storage); + } + else + { // reuse sm80 kernel for other types, align with dispatchToArch + run_kernel(params, shared_storage); + } +#elif (__CUDA_ARCH__ >= 900) + run_kernel(params, shared_storage); +#else + static_assert( + false, "Invalid architecture being compiled. 
Only Volta+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h new file mode 100644 index 00000000000..796dc2fe78d --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h @@ -0,0 +1,344 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Base scheduler for grouped problems, using MoE +*/ + +#pragma once + +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct BaseMoeProblemVisitor +{ + using ThreadblockShape = ThreadblockShape_; + + struct ProblemInfo + { + static int32_t const kNoPrefetchEntry = -1; + int32_t problem_idx; + int32_t problem_start; + + CUTLASS_DEVICE + ProblemInfo() + : problem_idx(kNoPrefetchEntry) + , problem_start(kNoPrefetchEntry) + { + } + + CUTLASS_DEVICE + ProblemInfo(int32_t problem_idx_, int32_t problem_start_) + : problem_idx(problem_idx_) + , problem_start(problem_start_) + { + } + }; + + struct Params + { + int64_t const* last_row_for_problem; + int64_t gemm_n; + int64_t gemm_k; + int32_t problem_count; + void const* workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params() + : last_row_for_problem(nullptr) + , gemm_n(0) + , gemm_k(0) + , problem_count(0) + , workspace(nullptr) + , tile_count(0) + { + } + + /// Ctor + CUTLASS_HOST_DEVICE + Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count, + void const* workspace = nullptr, int32_t tile_count = 0) + : last_row_for_problem(last_row_for_problem) + , gemm_n(gemm_n) + , gemm_k(gemm_k) + , problem_count(problem_count) + , workspace(workspace) + , tile_count(tile_count) + { + } + }; + + Params const& params; + int32_t tile_idx; + int32_t problem_tile_start; + int32_t problem_idx; + + // + // Methods + // + CUTLASS_DEVICE + BaseMoeProblemVisitor(Params const& params_, int32_t block_idx) + : params(params_) + , tile_idx(block_idx) + , 
problem_tile_start(0) + , problem_idx(0) + { + } + + /// Get the grid shape + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem) + { + + return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1); + } + + /// Gets the global tile index + CUTLASS_HOST_DEVICE + int32_t tile_index() const + { + return tile_idx; + } + + /// Gets the index of the problem + CUTLASS_HOST_DEVICE + int32_t problem_index() const + { + return problem_idx; + } + + CUTLASS_HOST_DEVICE + int32_t threadblock_idx() const + { + return tile_idx - problem_tile_start; + } + + CUTLASS_DEVICE + void advance(int32_t grid_size) + { + tile_idx += grid_size; + } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) + { + ProblemSizeHelper::possibly_transpose_problem(problem); + } + + /// Returns the problem size for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size() const + { + return problem_size(problem_idx); + } + + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size(int idx) const + { + const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; + const int64_t current_problem_row = params.last_row_for_problem[idx]; + const int64_t gemm_m = current_problem_row - prev_problem_row; + GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k)); + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + CUTLASS_HOST_DEVICE + static int32_t tile_count(cutlass::gemm::GemmCoord const& grid) + { + return ProblemSizeHelper::tile_count(grid); + } + + static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count) + { + int32_t total_tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) + { + auto problem = host_problem_sizes_ptr[i]; + possibly_transpose_problem(problem); + auto grid = grid_shape(problem); + total_tiles += tile_count(grid); + } + + return total_tiles; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeProblemVisitor; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// ProblemVisitor that performs all scheduling on device +// +template +struct MoeProblemVisitor : public BaseMoeProblemVisitor +{ + using Base = BaseMoeProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static bool const kRequiresPrecomputation = false; + static int const kThreadsPerWarp = 32; + + struct SharedStorage + { + }; + + // Final tile of the problem loaded by this thread. Each thread will hold + // a separate value. + int32_t problem_ending_tile; + + SharedStorage& shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, block_idx) + , problem_ending_tile(0) + , shared_storage(shared_storage_) + { + this->problem_idx = -1 * kThreadsPerWarp; + this->problem_tile_start = 0; + } + + CUTLASS_DEVICE + bool next_tile() + { + // Check whether the tile to compute is within the range of the current problem. 
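// --- Editorial sketch (not part of the original patch) ----------------------
// grid_shape() above is a ceiling division of each expert's problem by the
// threadblock tile, and group_tile_count() simply sums those grids over all
// experts to get the number of persistent tiles. A host-side rehearsal with
// made-up per-expert sizes.
#include <cstdint>
#include <cstdio>

int main()
{
    const int64_t kTileM = 128, kTileN = 128;
    const int64_t expert_m[] = {300, 5, 1024};   // rows per expert (illustrative)
    const int64_t gemm_n = 512;

    int64_t total_tiles = 0;
    for (int64_t m : expert_m)
    {
        const int64_t tiles_m = (m - 1 + kTileM) / kTileM;       // ceil(m / kTileM)
        const int64_t tiles_n = (gemm_n - 1 + kTileN) / kTileN;  // ceil(n / kTileN)
        total_tiles += tiles_m * tiles_n;
    }
    std::printf("total persistent tiles across experts: %lld\n", (long long) total_tiles);
    return 0;
}
// ----------------------------------------------------------------------------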
+ int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); + if (this->tile_idx < problem_tile_end) + { + return true; + } + + // Check whether the tile to compute is within the current group of problems fetched by the warp. + // The last tile for this group is the final tile of the problem held by the final thread in the warp. + int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + // Keep the starting problem for this group in `problem_idx`. This is done to reduce + // register pressure. The starting problem for this group is simply the first problem + // in the group most recently fetched by the warp. + int32_t& group_problem_start = this->problem_idx; + group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp; + + // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce + // register pressure. + int32_t& group_tile_start = this->problem_tile_start; + + // Each thread in the warp processes a separate problem to advance until + // reaching a problem whose starting tile is less less than tile_idx. + while (group_tile_end <= this->tile_idx) + { + group_problem_start += kThreadsPerWarp; + if (group_problem_start > this->params.problem_count) + { + return false; + } + + // Since `group_tile_start` is a reference to `this->problem_tile_start`, this + // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start` + // is also set here is used later in `next_tile`. + group_tile_start = group_tile_end; + + int lane_idx = threadIdx.x % kThreadsPerWarp; + int32_t lane_problem = group_problem_start + lane_idx; + + // Compute the number of tiles in the problem assigned to each thread. + problem_ending_tile = 0; + if (lane_problem < this->params.problem_count) + { + cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + problem_ending_tile = this->tile_count(grid); + } + + // Compute a warp-wide inclusive prefix sum to compute the ending tile index of + // each thread's problem. + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kThreadsPerWarp; i <<= 1) + { + int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i); + if (lane_idx >= i) + { + problem_ending_tile += val; + } + } + + // The total tile count for this group is now in the final position of the prefix sum + int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + problem_ending_tile += group_tile_start; + group_tile_end += tiles_in_group; + } + + // The next problem to process is the first one that does not have ending tile position + // that is greater than or equal to tile index. + int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx)); + + this->problem_idx = group_problem_start + problem_idx_in_group; + + // The starting tile for this problem is the ending tile of the previous problem. In cases + // where `problem_idx_in_group` is the first problem in the group, we do not need to reset + // `problem_tile_start`, because it is set to the previous group's ending tile in the while + // loop above. 
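// --- Editorial sketch (not part of the original patch) ----------------------
// next_tile() assigns one problem per lane, builds a warp-wide inclusive
// prefix sum of per-problem tile counts with __shfl_up_sync, and then uses
// __ballot_sync + __popc to find which problem owns the current tile index.
// A standalone CUDA rehearsal of that scan-and-search pattern (the uniform
// 4-tiles-per-problem input is illustrative).
#include <cstdio>

__global__ void warp_scan_search(const int* tiles_per_problem, int tile_idx)
{
    const unsigned mask = 0xffffffffu;
    const int lane = threadIdx.x % 32;

    // Inclusive prefix sum of tile counts across the warp.
    int ending_tile = tiles_per_problem[lane];
    for (int i = 1; i < 32; i <<= 1)
    {
        int v = __shfl_up_sync(mask, ending_tile, i);
        if (lane >= i) ending_tile += v;
    }

    // The number of problems whose ending tile is <= tile_idx is exactly the
    // index of the problem that owns tile_idx.
    int problem_idx = __popc(__ballot_sync(mask, ending_tile <= tile_idx));
    if (lane == 0)
        printf("tile %d belongs to problem %d\n", tile_idx, problem_idx);
}

int main()
{
    int h_tiles[32];
    for (int i = 0; i < 32; ++i) h_tiles[i] = 4;   // 4 tiles per problem (illustrative)

    int* d_tiles = nullptr;
    cudaMalloc(&d_tiles, sizeof(h_tiles));
    cudaMemcpy(d_tiles, h_tiles, sizeof(h_tiles), cudaMemcpyHostToDevice);

    warp_scan_search<<<1, 32>>>(d_tiles, /*tile_idx=*/10);   // expects problem 2
    cudaDeviceSynchronize();
    cudaFree(d_tiles);
    return 0;
}
// ----------------------------------------------------------------------------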
+ if (problem_idx_in_group > 0) + { + this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1); + } + + return true; + } + + static size_t get_workspace_size( + cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count) + { + return 0; + } + + static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, + int32_t block_count, void* host_workspace_ptr) + { + } +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp new file mode 100644 index 00000000000..e3d31a2c5b3 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp @@ -0,0 +1,646 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/tensor.hpp" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/workspace.h" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel +{ + +/////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalGated + && CollectiveMainloop_::isGated>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using Activation = typename CollectiveMainloop::Activation; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + using TileSchedulerTag = TileScheduler_; + using TileScheduler = + typename detail::TileSchedulerSelector::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock + = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>; + + // Kernel level shared memory storage + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + 
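// --- Editorial sketch (not part of the original patch) ----------------------
// Thread accounting for the cooperative gated kernel above: the TiledMma
// spans 256 threads (two 128-thread math warp groups) and one extra warp
// group acts as the producer, so each block launches 384 threads. The
// register requirements (40 for the load warp group, 232 for the math warp
// groups) rebalance the register file toward the MMA warps after the
// warpgroup register dealloc/alloc calls later in the kernel. A compile-time
// rehearsal of the arithmetic; the kSketch* names are illustrative.
static constexpr unsigned kSketchThreadsPerWarpGroup = 128;
static constexpr unsigned kSketchTiledMmaThreads     = 256;  // cooperative: 2 math warp groups
static constexpr unsigned kSketchLoadWarpGroups      = 1;

static constexpr unsigned kSketchMmaWarpGroups = kSketchTiledMmaThreads / kSketchThreadsPerWarpGroup;
static constexpr unsigned kSketchBlockThreads
    = kSketchTiledMmaThreads + kSketchLoadWarpGroups * kSketchThreadsPerWarpGroup;

static_assert(kSketchMmaWarpGroups == 2, "two consumer (math) warp groups");
static_assert(kSketchBlockThreads == 384, "producer + two consumers = 384 threads per block");
// ----------------------------------------------------------------------------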
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> + { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + void* workspace{nullptr}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static Params to_underlying_arguments(Arguments const& args, void* workspace) + { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + auto problem_shape = args.problem_shape; + // if constexpr (detail::IF_SWAP_AB::value) { + // // swap M/N + // get<0>(problem_shape) = get<1>(args.problem_shape); + // get<1>(problem_shape) = get<0>(args.problem_shape); + // } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) + { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* scheduler_workspace = workspace_ptr; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used + // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means + // subtile will not be used, therefore separate reduction will not be enabled. 
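// --- Editorial sketch (not part of the original patch) ----------------------
// to_underlying_arguments() carves a single workspace allocation into a tile
// scheduler region and an epilogue region, rounding each offset up to a
// minimum alignment between regions. The same carving pattern in isolation;
// round_up stands in for CUTLASS's round_nearest, and the alignment and byte
// sizes below are illustrative.
#include <cassert>
#include <cstddef>

inline size_t round_up(size_t x, size_t alignment)
{
    return ((x + alignment - 1) / alignment) * alignment;
}

int main()
{
    const size_t kAlign = 16;

    size_t offset = 0;
    const size_t scheduler_ws_offset = offset;
    offset = round_up(offset + 100, kAlign);   // scheduler region needs 100 B

    const size_t epilogue_ws_offset = offset;
    offset = round_up(offset + 40, kAlign);    // epilogue region needs 40 B

    assert(scheduler_ws_offset == 0);
    assert(epilogue_ws_offset == 112);         // 100 rounded up to 112
    assert(offset == 160);                     // 112 + 40 = 152, rounded up to 160
    return 0;
}
// ----------------------------------------------------------------------------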
+ constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, + ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + + return {args.mode, problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, + scheduler, workspace}; + } + + static bool can_implement(Arguments const& args) + { + bool implementable = (args.mode == GemmUniversalMode::kGemm) + or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) + { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t get_workspace_size(Arguments const& args) + { + size_t workspace_size = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) + { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + status = TileScheduler::template initialize_workspace(args.scheduler, + workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, + NumEpilogueSubTiles); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + status = CollectiveEpilogue::initialize_workspace( + args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 get_grid_shape(Params const& params) + { + // Given device SM count, set grid size s.t. 
we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) + { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN + ? TileScheduler::RasterOrderOptions::AlongN + : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 get_block_shape() + { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) + { + using namespace cute; + using X = Underscore; + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if !defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + // Preconditions + static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); + static_assert(size<0>(TileShape{}) >= 128, + "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); + + static_assert(cute::rank(StrideA{}) == 3, + "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, + "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ + enum class WarpGroupRole + { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + enum class ProducerWarpRole + { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int mma_thread_idx = thread_idx % size(TiledMma{}); + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) + { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = 
warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = size(TiledMma{}); + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = size(TiledMma{}); + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = []() + { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) + { + cute::cluster_arrive_relaxed(); + return []() { cute::cluster_wait(); }; + } + else + { + __syncthreads(); + return []() {}; // do nothing + } + }(); + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + TileScheduler scheduler{params.scheduler}; + auto work_tile_info = scheduler.get_current_work(); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, 
shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 3, + "Output of load_init must have at least three elements (A, B, Aux)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) + { + cutlass::arch::warpgroup_reg_dealloc(); + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) + { + bool do_load_order_arrive = true; + while (work_tile_info.is_valid()) + { + if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) + { + work_tile_info = fetch_next_work(work_tile_info, scheduler); + continue; + } + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the + // work. + auto work_k_tile_count + = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter + = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + + collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, + load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx, block_rank_in_cluster, + shared_storage.tensors.mainloop); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(work_k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) + { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) + { + while (work_tile_info.is_valid()) + { + if (!TileScheduler::requires_separate_reduction(params.scheduler)) + { + load_order_barrier.wait(); + } + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state 
= collective_epilogue.load(epi_load_pipeline, + epi_load_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx, + shared_storage.tensors.epilogue, work_tile_info.reduction_subtile_idx()); + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + + else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + cutlass::arch::warpgroup_reg_alloc(); + + // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it + bool do_store_tail = false; + float scale_d0 = params.mainloop.scale_d0; + float scale_d1 = params.mainloop.scale_d1; + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + auto work_k_tile_count + = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + + // Allocate the accumulators for the (M,N) blk_shape + // + // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. + auto accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + auto accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) + { + collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, + accumulators1, work_k_tile_count, mma_thread_idx, shared_storage.tensors.mainloop, + params.mainloop); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, work_k_tile_count); + + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(work_k_tile_count); + } + // Index of warp group within consumer warp groups + int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; + + // Perform reduction across splits, if needed + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators0, NumMmaWarpGroups, consumer_warp_group_idx); + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators1, NumMmaWarpGroups, consumer_warp_group_idx); + + Activation elt_op; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators0); i++) + { + accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]); + } + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) + { + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] + = collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, + epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0, + tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx()); + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; + epi_store_pipe_producer_state = 
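// --- Editorial sketch (not part of the original patch) ----------------------
// The "gated" part of GemmUniversalGated is the elementwise combine in the
// consumer loop above: two accumulators are produced per tile (a value path
// and a gate path) and merged as (acc0 * scale_d0) * Activation(scale_d1 * acc1)
// before a single D is stored. With a SiLU activation, assumed here purely
// for illustration (the real Activation comes from the mainloop type), this
// is the familiar SwiGLU-style gating. A scalar rehearsal:
#include <cmath>
#include <cstdio>

inline float silu(float x) { return x / (1.0f + std::exp(-x)); }   // assumed activation

int main()
{
    const float scale_d0 = 1.0f, scale_d1 = 1.0f;
    float acc0[4] = {0.5f, -1.0f, 2.0f, 0.0f};   // value-path accumulators
    float acc1[4] = {1.0f, 0.5f, -2.0f, 3.0f};   // gate-path accumulators

    for (int i = 0; i < 4; ++i)
        acc0[i] = (acc0[i] * scale_d0) * silu(scale_d1 * acc1[i]);

    for (float v : acc0) std::printf("%f ", v);
    std::printf("\n");
    return 0;
}
// ----------------------------------------------------------------------------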
epi_store_pipe_producer_state_next; + do_store_tail = true; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + if (do_store_tail) + { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, epi_store_pipe_producer_state); + } + } // Consumer Warp Groups End +#endif + } + +private: + // Kernel helper function to get next work unit + CUTLASS_DEVICE + typename TileScheduler::WorkTileInfo fetch_next_work( + typename TileScheduler::WorkTileInfo& work_tile_info, TileScheduler& scheduler) const + { + // Check whether we should continue on with the current work unit. If this is the case, + // the work unit will have been updated in continue_current_work to reflect the new + // tile to be computed. + if (scheduler.continue_current_work(work_tile_info)) + { + return work_tile_info; + } + + // Get next work tile + scheduler.advance_to_next_work(); + return scheduler.get_current_work(); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp new file mode 100644 index 00000000000..39886f2431d --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp @@ -0,0 +1,621 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" + +#include "cute/util/debug.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel +{ + +/////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalGated + && CollectiveMainloop_::isGated>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using Activation = typename CollectiveMainloop::Activation; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(!cute::is_same_v, + "Ping-pong kernel does not currently support stream-K scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = + typename detail::TileSchedulerSelector::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 2; + static constexpr uint32_t MaxThreadsPerBlock + = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>; + + // Order Sequence barrier with two 
stages: one for Mainloop and one for Epilogue + static constexpr uint32_t StagesPerMathWarpGroup = 2; + using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier; + + // Kernel level shared memory storage + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> + { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static Params to_underlying_arguments(Arguments const& args, void* workspace) + { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + (void) workspace; + auto problem_shape = args.problem_shape; + // if constexpr (detail::IF_SWAP_AB::value) { + // // swap M/N + // get<0>(problem_shape) = get<1>(args.problem_shape); + // get<1>(problem_shape) = get<0>(args.problem_shape); + // } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) + { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* scheduler_workspace = workspace_ptr; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + return {args.mode, problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, 
args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, + TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace)}; + } + + static bool can_implement(Arguments const& args) + { + bool implementable = (args.mode == GemmUniversalMode::kGemm) + or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) + { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t get_workspace_size(Arguments const& args) + { + size_t workspace_size = 0; + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) + { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = TileScheduler::template initialize_workspace(args.scheduler, + workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + status = CollectiveEpilogue::initialize_workspace( + args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 get_grid_shape(Params const& params) + { + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) + { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN + ? TileScheduler::RasterOrderOptions::AlongN + : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 get_block_shape() + { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) + { + using namespace cute; + using X = Underscore; + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. 
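+// When this translation unit is not compiled for sm90a, only the error printf below is emitted and
+// the kernel body in the #else branch is compiled out, so a mis-targeted build fails loudly at
+// runtime instead of issuing unsupported WGMMA instructions.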
+#if !defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, + "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, + "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + enum class WarpGroupRole + { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + enum class ProducerWarpRole + { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) + { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, 
epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + + typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; + // DMA Load WG will not participate in these Ordered Barrier syncs + params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); + params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group + MathWarpGroupOrderBarrier math_wg_order_barrier( + shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = [&]() + { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) + { + cute::cluster_arrive_relaxed(); + return []() { cute::cluster_wait(); }; + } + else + { + __syncthreads(); + return []() {}; // do nothing + } + }(); + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 3, + "Output of load_init must have at least three elements (A, B, Aux)"); + + // Extract out partitioned A and B. 
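+ // load_init returns the gated-GEMM auxiliary operand alongside A and B (hence the static_assert
+ // above requiring at least three tensors); that extra tensor is what produces the second
+ // accumulator which the consumer warp groups later combine with the main result through the
+ // Activation functor.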
+ Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + + TileScheduler scheduler{params.scheduler}; + + if (warp_group_role == WarpGroupRole::Consumer1) + { + // Advance 2nd Math WG to the next work tile for the startup + scheduler.advance_to_next_work(); + // Advance 2nd Math WG pipeline states to the end of 1st Math WG + mainloop_pipe_consumer_state.advance(k_tile_count); + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + } + auto work_tile_info = scheduler.get_current_work(); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) + { + cutlass::arch::warpgroup_reg_dealloc(); + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) + { + bool do_load_order_arrive = true; + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); + + collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, + load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, block_rank_in_cluster, + shared_storage.tensors.mainloop); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) + { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) + { + load_order_barrier.wait(); + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state + = collective_epilogue.load(epi_load_pipeline, epi_load_pipe_producer_state, problem_shape_MNKL, + blk_shape, blk_coord, tiled_mma, lane_idx, shared_storage.tensors.epilogue); + + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + + else if (warp_group_role == 
WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + cutlass::arch::warpgroup_reg_alloc(); + + float scale_d0 = params.mainloop.scale_d0; + float scale_d1 = params.mainloop.scale_d1; + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Allocate the accumulators for the (M,N) blk_shape + Tensor accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + Tensor accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + // Order two Math WG's MMA one after the other, helps hide Epilogue + math_wg_order_barrier.wait(); + + collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, accumulators1, + k_tile_count, warp_group_thread_idx, shared_storage.tensors.mainloop, params.mainloop); + + // Cue for next Math WG's MMA to start + math_wg_order_barrier.arrive(); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count); + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups); + + Activation elt_op; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators0); i++) + { + accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]); + } + + // Order two Math WG's Epilogue one after the other + math_wg_order_barrier.wait(); + + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] + = collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, + epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0, + tiled_mma, warp_group_thread_idx, shared_storage.tensors.epilogue); + + // TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels + // we need to wait for all TMA stores to complete before issuing consumer order barrier arrives + // to ensure next math consumer doesn't overwrite smem of in-flight TMA stores of current consumer. 
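+ // store_tail drains any TMA stores still in flight for this tile and returns the fully
+ // advanced pipeline states, so the epilogue shared memory can be handed to the other math
+ // warp group safely.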
+ auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] + = collective_epilogue.store_tail(epi_load_pipeline, epi_load_pipe_consumer_state_next, + epi_store_pipeline, epi_store_pipe_producer_state_next); + + // Update starting load/store pipeline states for the next tile + // state has already been incremented by 1 tile in collective calls, advance once again for ping pong + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_; + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + + // Cue for next Math WG's Epilogue to start + math_wg_order_barrier.arrive(); + + // Get next work tile + scheduler.advance_to_next_work(NumMmaWarpGroups); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + } // Consumer Warp Groups End +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h new file mode 100644 index 00000000000..5e3531f0938 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h @@ -0,0 +1,494 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SplitkGemmGrouped +{ +public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = Transposed; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + + using ElementFinalOutput = typename MapArguments::ElementA; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. 
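+ // The aliases below re-export the mainloop's operator class, tile shapes, stage count and
+ // alignments so the grouped problem visitor and the epilogue can be parameterised consistently
+ // from the single Mma type.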
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor + = GemmGroupedProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments + { + + // + // Data members + // + + GemmCoord* problem_sizes; + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; + + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // splitK + int split_k_slices; + int64_t* splitk_buffer_offsets; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0) + , threadblock_count(0) + , ptr_A(nullptr) + , ptr_B(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , lda(nullptr) + , ldb(nullptr) + , ldc(nullptr) + , ldd(nullptr) + , host_problem_sizes(nullptr) + , split_k_slices(1) + , splitk_buffer_offsets(nullptr) + { + } + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count, + typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C, + ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda, + typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc, + typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices, + int64_t* splitk_buffer_offsets) + : problem_sizes(problem_sizes) + , problem_count(problem_count) + , threadblock_count(threadblock_count) + , output_op(output_op) + , ptr_A(ptr_A) + , ptr_B(ptr_B) + , ptr_C(ptr_C) + , ptr_D(ptr_D) + , lda(lda) + , ldb(ldb) + , ldc(ldc) + , ldd(ldd) + , host_problem_sizes(host_problem_sizes) + , split_k_slices(split_k_slices) + , splitk_buffer_offsets(splitk_buffer_offsets) + { + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params + { + + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; + ElementC* ptr_C_split; + ElementC* ptr_D_split; + + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; + + // + // Methods + // + + // splitk + GemmCoord grid_tiled_shape; + int swizzle_log_tile; + int gemm_k_size; + GemmCoord* host_problem_sizes; + int split_k_slices; + int64_t* splitk_buffer_offsets; + + CUTLASS_HOST_DEVICE + Params() + : ptr_A(nullptr) + 
, ptr_B(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , ptr_C_split(nullptr) + , ptr_D_split(nullptr) + , lda(nullptr) + , ldb(nullptr) + , ldc(nullptr) + , ldd(nullptr) + , swizzle_log_tile(0) + , gemm_k_size(0) + , host_problem_sizes(nullptr) + , split_k_slices(1) + , splitk_buffer_offsets(nullptr) + { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count) + , host_problem_sizes(args.host_problem_sizes) + , threadblock_count(args.threadblock_count) + , output_op(args.output_op) + , ptr_A(args.ptr_A) + , ptr_B(args.ptr_B) + , ptr_C(args.ptr_C) + , ptr_D(args.ptr_D) + , ptr_C_split((ElementC*) workspace) + , ptr_D_split((ElementC*) workspace) + , lda(args.lda) + , ldb(args.ldb) + , ldc(args.ldc) + , ldd(args.ldd) + , split_k_slices(args.split_k_slices) + , splitk_buffer_offsets(args.splitk_buffer_offsets) + { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0], + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices); + swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); + + // only support same k + int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK; + int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + } + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + { + + problem_visitor = + typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + ptr_C_split = workspace; + ptr_D_split = workspace; + + lda = args.lda; + ldb = args.ldb; + ldc = args.ldc; + ldd = args.ldd; + } + }; + + /// Shared memory storage structure + struct SharedStorage + { + union + { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + } kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + +public: + // + // Methods + // + + CUTLASS_DEVICE + SplitkGemmGrouped() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) + { + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { + + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + // + // Problem visitor. 
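+ // Each threadblock repeatedly asks the problem visitor for its next (problem, tile) pair and
+ // keeps going until the grouped workload is exhausted; see the persistent while-loop below.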
+ // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) + { + + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + // Load element pointers. Exchange pointers and strides if working on the transpose + ElementA* ptr_A + = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); + typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); + + ElementB* ptr_B + = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); + typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, + int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k; + if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) + { + problem_size_k = problem_size.k(); + } + else + { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + + // Wait for all threads to finish their epilogue phases from the previous tile. 
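+ // SharedStorage places the mainloop and epilogue buffers in a union, so the barrier below keeps
+ // this tile's mainloop from clobbering shared memory that the previous tile's epilogue may still
+ // be using.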
+ __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = params.ptr_C_split; + ElementC* ptr_D = params.ptr_D_split; + + LayoutC layout_C(params.ldc[problem_idx]); + LayoutC layout_D(params.ldd[problem_idx]); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // assume identity swizzle + MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n()); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C); + + iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C); + iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + + Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h new file mode 100644 index 00000000000..ed5e3e4daf8 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h @@ -0,0 +1,125 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ +//////////////////////////////////////////////////////////////////////////////// + +// We need to distinguish here, since we want volta support. It is too much effort +// to write shared memory iterators that are probably needed for volta to function +// properly. 
As a result, we allow converters both after the LDG (for volta) and after +// the LDS for Turing+. +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Warp level Mma + typename MmaOperator, + /// Math operation perform by warp level operator + typename MathOperator> +struct SetConverters +{ +}; + +// Dequantize after LDG, so set transforms accordingly +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters +{ + using TransformAfterLDG + = FastInterleavedAndBiasedNumericArrayConverter; + + using TransformAfterLDS = NumericArrayConverter; +}; + +// Dequantize after LDS, so set transforms accordingly + +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters +{ + using TransformAfterLDG = NumericArrayConverter; + + using TransformAfterLDS + = FastInterleavedAndBiasedNumericArrayConverter; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale_, + /// Layout for the scale operand + typename LayoutScale_, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// + typename Enable = void> +struct DqMma; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h new file mode 100644 index 00000000000..17c6346553c --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -0,0 +1,302 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/arch/mma.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h" +#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIteratorsMultistage; + +// Fine grained iterators +template +struct DefaultScaleIteratorsMultistage> +{ + using IteratorScale + = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +// Per column iterators +template +struct DefaultScaleIteratorsMultistage> +{ + // ThreadMap for scale iterator + static_assert((MmaShape::kN % Alignment) == 0, ""); + +private: + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + +public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, + Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && !layout::IsColumnMajorTileInterleave::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value + || platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename 
OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, + AccessTypeB>; + + using ScaleIterators = DefaultScaleIteratorsMultistage; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && layout::IsColumnMajorTileInterleave::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value + || platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize 
after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + +private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape + = MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + +public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator; + + using ScaleIterators = DefaultScaleIteratorsMultistage; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h new file mode 100644 index 00000000000..345cd2eec9a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -0,0 +1,284 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/arch/mma.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIteratorsPipelined; + +// Fine grained iterators +template +struct DefaultScaleIteratorsPipelined> +{ +private: + using SmemScaleType = half_t; + +public: + using IteratorScale + = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale + = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, + SmemScaleType, Layout, 0, Alignment>; +}; + +// Per column iterators +template +struct DefaultScaleIteratorsPipelined> +{ + static_assert((MmaShape::kN % Alignment) == 0, ""); + +private: + // ThreadMap for scale iterator + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + using SmemScaleType = half_t; + +public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, + Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale + = cutlass::transform::threadblock::PredicatedTileIterator, SmemScaleType, + Layout, 0, IteratorScaleThreadMap, Alignment>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + 
static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + static constexpr bool DqAfterLDG = platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + using ScaleIterators = DefaultScaleIteratorsPipelined; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + + static constexpr bool DqAfterLDG = platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + +private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = 
LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape + = MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + +public: + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap + = transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; + + using ScaleIterators = DefaultScaleIteratorsPipelined; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h new file mode 100644 index 00000000000..ad6c7496e14 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h @@ -0,0 +1,351 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of 
elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +#ifdef ENABLE_FP8 +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile 
size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +#endif + +// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma +{ + + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h new file mode 100644 index 00000000000..77af81005ab --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -0,0 +1,353 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma +{ + +private: + // Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS. + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaElementA = typename platform::conditional::type; + using MmaElementB = typename platform::conditional::type; + +public: + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined; +}; + +// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma +{ + + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, + AccessTypeA, GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, + AccessTypeB, GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & 
int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory 
clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h new file mode 100644 index 00000000000..1fb7f7eb28f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h @@ -0,0 +1,257 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/weight_only_quant_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// +// SFINAE trick so I can keep the same loop code for Volta and dispatch to the +// correct warp level mma. On volta, all data is stored to shared memory as FP16. +template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, + typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C, + int const warp_tileB_k_offset) +{ + warp_mma(D, A, B, C); +} + +template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, + typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B, + typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset) +{ + warp_mma(D, A, B, C, warp_tileB_k_offset); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// The type of the scales + typename ElementScale_, + /// Number of stages, + int Stages, + /// The dequantizing op to be performed. + WeightOnlyQuantOp DequantOp, + /// Used for partial specialization, + typename Enable = bool> +class DqMmaBase +{ +public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + ///< Type of the scale to be loaded + using ElementScale = ElementScale_; + + static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); + + // Finegrained scales get streamed in via cp.async + static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; + // We always have scales. + static constexpr int ScaleElementsPerStage = Shape::kN; + // We sometimes have a bias + static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + static constexpr int kNumKIterationsPerWarpBLoad + = Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage + { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA + = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB + = MatrixShape; + + /// Shape of the shared memory buffer for the scales for the B matrix. + using ShapeScale = MatrixShape; + /// Shape of the shared memory buffer for the biases of the B matrix. + using ShapeZero = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_scale; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_zero; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() + { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() + { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() + { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() + { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + +protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx) + , warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h new file mode 100644 index 00000000000..3c4036dd8cc --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -0,0 +1,110 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = void> +class DqMmaMultistage; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h" +#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h" diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h new file mode 100644 index 00000000000..f81961dee3c --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h @@ -0,0 +1,708 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applied immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + + /// Internal structure exposed for introspection. 
+ struct Detail + { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA + = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB + = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + /// The group size for quantization + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), + shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in 
units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1) + { + static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); + + typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale(); + typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero(); + + typename IteratorScale::AccessType* smem_scale_ptr + = reinterpret_cast(this->smem_iterator_scale_.get_scale()); + typename IteratorScale::AccessType* smem_zero_ptr + = reinterpret_cast(this->smem_iterator_scale_.get_zero()); + + int const kSrcBytes = sizeof_bits::value * IteratorScale::kAlignment / 8; + + cutlass::arch::cp_async(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) + { + cutlass::arch::cp_async(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); + } + + if (iterator_scale.group_size_ == 64) + { + iterator_scale.add_tile_offset({1, 0}); + } + else if (iterator_scale.group_size_ == 128) + { + if constexpr (Shape::kK == 128) + { + iterator_scale.add_tile_offset({1, 0}); + } + else if constexpr (Shape::kK == 64) + { + if (iterator_scale.row_groupsize64_ & 0x1) + { + iterator_scale.add_tile_offset({1, 0}); + } + } + else + { + static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) + { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) + { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) + { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + 
cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) + { + + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) + { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. 
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) + { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + typename Dequantizer::FragmentZero warp_frag_zeros; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) + { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) + { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); + + using FragmentOperandB = cutlass::Array; + constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; + static_assert(ConversionVectorWidth == FragmentOperandB::kElements); + + using Converter + = cutlass::NumericArrayConverter; + + FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); + run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, + warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) + { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + + // This is the first group of a given stage, so we issue the loads for the B scales immediately. + if (group_start_iteration_B == 0) + { + copy_scales_and_advance(iterator_scale); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) + { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - + // #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } + else + { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + smem_read_stage_idx = 0; + } + else + { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + } + } + + // Load the scale needed for the next tile iteration. + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + // Update internal pointer to set of scales in shared memory. + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h new file mode 100644 index 00000000000..83efdc5cb01 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h @@ -0,0 +1,647 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
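+/// This per-column variant of the multistage dequantizing mainloop stages a single row of scales
+/// per threadblock tile into shared memory during the prologue and reuses it for every K
+/// iteration; unlike the fine-grained variant above, it carries no zero-point fragments and does
+/// not reload scales inside the mainloop.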
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
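+    /// Detail derives, from the iterator thread maps, how many cp.async operations fill one
+    /// pipeline stage of A and B and how those copies are split into groups
+    /// (kAccessesPerGroupA / kAccessesPerGroupB) so that one group can be issued per
+    /// warp-level MMA iteration.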
+ struct Detail + { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA + = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB + = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + ///< Group size for quantization. 
Not used by this main loop since it assumes per-column + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) + { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) + { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) + { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, 
iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) + { + + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group as + // the first load of A. + FragmentScale tb_frag_scales; + tb_frag_scales.clear(); + iterator_scale.load(tb_frag_scales); + this->smem_iterator_scale_.store(tb_frag_scales); + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) + { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. 
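+        // When kClearLastStage is selected, the one shared-memory stage that the prologue above
+        // did not populate is overwritten with zeroed access vectors for both A and B, so the
+        // mainloop never consumes uninitialized shared memory from that final stage.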
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) + { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + warp_dequantizer_.load(warp_frag_scales); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) + { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) + { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + + using FragmentOperandB = cutlass::Array; + constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; + static_assert(ConversionVectorWidth == FragmentOperandB::kElements); + + using Converter + = cutlass::NumericArrayConverter; + + FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); + run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, + warp_tileB_k_compute_offset); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) + { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) + { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - + // #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } + else + { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + smem_read_stage_idx = 0; + } + else + { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + } + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h new file mode 100644 index 00000000000..bd3e38971b0 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h @@ -0,0 +1,106 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Used for partial specialization + typename Enable = void> +class DqMmaPipelined; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h" +#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h" diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h new file mode 100644 index 00000000000..50bdd0d85b0 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h @@ -0,0 +1,486 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
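+/// This is the fine-grained (group-wise) double-buffered specialization: kStages is fixed at 2,
+/// scale and zero-point fragments are refetched from global memory for every K tile via
+/// copy_scales_and_advance, and the dequantizer consumes both scales and zeros before each
+/// warp-level MMA. It is the pipelined (non-cp.async) counterpart of the multistage kernels above.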
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_> +class DqMmaPipelined> + : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); + + static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == 
Shape::kN, ""); + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + using WarpFragmentZero = typename Dequantizer::FragmentZero; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined(typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int const group_size, ///< The group size for quantization + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), + shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) + { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale) + { + using TransformScale = NumericArrayConverter; + + FragmentScale tb_frag_scales; + FragmentScale tb_frag_zeros; + tb_frag_scales.clear(); + tb_frag_zeros.clear(); + + TransformScale transformScale; + + using FragmentElement = typename FragmentScale::Element; + + auto gmem_scale_ptr = iterator_scale.get_scale(); + auto gmem_zero_ptr = iterator_scale.get_zero(); + + arch::global_load(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) + { + arch::global_load( + tb_frag_zeros, 
gmem_zero_ptr, iterator_scale.valid()); + } + + typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales); + typename TransformScale::result_type tb_frag_zeros_fp16; + if (gmem_zero_ptr != nullptr) + tb_frag_zeros_fp16 = transformScale(tb_frag_zeros); + + auto frag_scale_ptr_fp16 = reinterpret_cast(&tb_frag_scales_fp16); + auto frag_zero_ptr_fp16 = reinterpret_cast(&tb_frag_zeros_fp16); + auto smem_scale_ptr = this->smem_iterator_scale_.get_scale(); + auto smem_zero_ptr = this->smem_iterator_scale_.get_zero(); + + if (iterator_scale.valid()) + { + auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr); + arch::shared_store(smem_offset, frag_scale_ptr_fp16); + + if (gmem_zero_ptr != nullptr) + { + smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr); + arch::shared_store(smem_offset, frag_zero_ptr_fp16); + } + } + + if (iterator_scale.group_size_ == 64) + { + iterator_scale.add_tile_offset({1, 0}); + } + else if (iterator_scale.group_size_ == 128) + { + if constexpr (Shape::kK == 128) + { + iterator_scale.add_tile_offset({1, 0}); + } + else if constexpr (Shape::kK == 64) + { + if (iterator_scale.row_groupsize64_ & 0x1) + { + iterator_scale.add_tile_offset({1, 0}); + } + } + else + { + static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) + { ///< source accumulator tile + + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using TransformA + = NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want + // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
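+        // The prologue below stages the first K tile of A and B (converted by TransformA and
+        // TransformBAfterLDG) together with the first group of scales (and zero points, when
+        // present) into shared memory, so the double-buffered mainloop always has one K tile
+        // resident before any warp-level math is issued.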
+ TransformA transformA; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + copy_scales_and_advance(iterator_scale); + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + WarpFragmentScale warp_frag_scales; + WarpFragmentZero warp_frag_zero; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + iterator_scale.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) + { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) + { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + } + else + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
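+                // B is loaded at a coarser k granularity than it is consumed: each B fragment
+                // covers kNumKIterationsPerWarpBLoad compute sub-iterations, so the next fragment
+                // is fetched only on the last sub-iteration of the current one.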
+ if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) + { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) + { + + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + copy_scales_and_advance(iterator_scale); + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + iterator_scale.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero); + run_warp_mma( + warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); + } + + // Load the scales needed for the next tile iteration + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + // Update internal pointer to the set of scales in shared memory + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h new file mode 100644 index 00000000000..316ea9f80a9 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h @@ -0,0 +1,399 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_> +class DqMmaPipelined> + : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + 
using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined(typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int const group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation + ///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this + ///< argument is not added, it does not affect compilation for sm>=80. 
+ int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) + { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) + { ///< source accumulator tile + + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using TransformA + = NumericArrayConverter; + + using TransformScale = NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want + // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
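+        // In this per-column path the scale fragment is loaded and converted once in the prologue
+        // (TransformScale runs before the store to shared memory) and is then reused by the warp
+        // dequantizer for the entire mainloop; there is no per-group reload and no zero-point path.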
+ TransformA transformA; + TransformScale transformScale; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + FragmentScale tb_frag_scales; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + WarpFragmentScale warp_frag_scales; + + tb_frag_A.clear(); + tb_frag_B.clear(); + tb_frag_scales.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + iterator_scale.load(tb_frag_scales); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + warp_dequantizer_.load(warp_frag_scales); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) + { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) + { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
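// A minimal sketch (not part of the upstream file) of the two-stage, double-buffered
// pattern this mainloop implements: while tile k is consumed from one buffer, tile
// k+1 is loaded into the other, and the roles flip each iteration. Plain arrays stand
// in for the shared-memory stages; on the device the load and the math genuinely
// overlap in time, which a sequential host program can only hint at.
#include <cstdio>

int main() {
    const int kTiles = 4, kTileElems = 8;
    int global[kTiles][kTileElems];                       // stands in for global memory
    for (int t = 0; t < kTiles; ++t)
        for (int i = 0; i < kTileElems; ++i) global[t][i] = t * 10 + i;

    int stage[2][kTileElems];                             // two "shared memory" buffers
    auto load = [&](int tile, int buf) {                  // stands in for GMEM -> SMEM copy
        for (int i = 0; i < kTileElems; ++i) stage[buf][i] = global[tile][i];
    };

    long long acc = 0;
    int write_buf = 0;
    load(0, write_buf);                                   // prologue: fill the first stage
    for (int t = 0; t < kTiles; ++t) {
        int read_buf = write_buf;                         // consume the tile loaded last
        write_buf ^= 1;                                   // ping-pong, like smem_write_stage_idx
        if (t + 1 < kTiles) load(t + 1, write_buf);       // prefetch the next tile
        for (int i = 0; i < kTileElems; ++i) acc += stage[read_buf][i];  // "compute"
    }
    printf("sum = %lld (expected 592)\n", acc);
    return 0;
}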
+ if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) + { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) + { + + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + run_warp_mma( + warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h new file mode 100644 index 00000000000..350b247de2e --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" + +namespace cutlass +{ +namespace gemm +{ +namespace warp +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements, + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp +{ + +private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; + + // Chosen so we get K=16 for int8 and K=32 for int4. + static constexpr int LoadInstructionK = 128 / sizeof_bits::value; + + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = GemmShape; + +public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1>>; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h new file mode 100644 index 00000000000..7c5088894b4 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -0,0 +1,306 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations targeting + Tensor Cores. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" +#include "cutlass/arch/mma_sm89.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace warp +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Instruction shape to override shared memory iterators with + typename SharedMemoryInstructionShape_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool> +class MmaTensorOpComputeBWithF16 +{ +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert((platform::is_same::value + && platform::is_same::value) + || (platform::is_same::value + && platform::is_same::value + && ArchTag::kMinComputeCapability >= 80) + || (platform::is_same::value + && platform::is_same::value + && ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); + + static_assert(platform::is_same::value + || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80) + || (platform::is_same::value && ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert( + SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load"); + static_assert( + SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + +public: + /// Iterates over the A operand in memory + using IteratorA + = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, + MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, + LayoutB, MatrixShape, Policy::OpDelta::kRow, + kThreadCount, kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using 
TransformedFragmentB = Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + +public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, + int const warp_tileB_k_offset) const + { + + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + static_assert( + TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " + "B"); + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) + { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) + { + + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) + { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } + else + { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) + { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) + { + + int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) + { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } + else + { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h new file mode 100644 index 00000000000..1d5cd5d8985 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -0,0 +1,463 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +#include "cutlass_extensions/weight_only_quant_op.h" +#include "tensorrt_llm/common/cudaBf16Wrapper.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace warp +{ + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Number of threads participating in one matrix operation + int Threads, + /// + WeightOnlyQuantOp QuantOp_, + /// + typename Enable = void> +class MmaTensorOpDequantizer; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer= 80 + && platform::is_same::value>::type> +{ + +public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. 
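// A small sketch (not part of the upstream header) of the ratio mentioned above,
// assuming the 16x8x16 fp16/bf16 tensor-core MMA shape used on Ampere: the K extent of
// one shared-memory load is chosen as 128 / bits(ElementB) (K = 16 for int8, K = 32
// for int4, per default_mma_tensor_op.h earlier in this patch), while one MMA consumes
// K = 16, so the expansion factor is 1 for int8 weights and 2 for int4 weights.
#include <cstdio>

constexpr int load_k(int weight_bits) { return 128 / weight_bits; }  // K covered by one SMEM load
constexpr int kComputeK = 16;                                        // K of one 16x8x16 MMA

int main() {
    static_assert(load_k(8) / kComputeK == 1, "int8: one load feeds one mma K-step");
    static_assert(load_k(4) / kComputeK == 2, "int4: one load feeds two mma K-steps");
    printf("expansion factor: int8 = %d, int4 = %d\n",
           load_k(8) / kComputeK, load_k(4) / kComputeK);
    return 0;
}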
+ static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = bfloat16_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) + { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) + { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) + { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) + { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. 
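// A scalar sketch (not part of the upstream file) of what dequantize() computes per
// weight element once the raw integers have been converted to fp16/bf16: multiply by
// the per-column scale, and add the per-column zero point when the quantization
// scheme carries one. Plain float stands in for the half2 / bfloat162 vector math,
// and the example values are made up for illustration.
#include <cstdio>

int main() {
    const int kCols = 4;
    float w_conv[kCols] = {-3.f, 7.f, 0.f, 5.f};   // converted, not yet scaled, weights
    float scale[kCols]  = {0.5f, 0.25f, 2.f, 1.f}; // one scale per output column
    float zero[kCols]   = {1.f, -0.5f, 0.f, 2.f};  // optional per-column zero point
    bool has_zero = true;                          // mirrors hasZero(QuantOp)

    for (int n = 0; n < kCols; ++n) {
        float w = w_conv[n] * scale[n];            // the __hmul2 path
        if (has_zero) w += zero[n];                // the __hfma2 path
        printf("column %d: dequantized weight %+.2f\n", n, w);
    }
    return 0;
}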
+ arch::device_breakpoint(); +#endif + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) + { + if constexpr (hasZero(QuantOp)) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize( + FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) + { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + __nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag); + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + if constexpr (hasZero(QuantOp)) + { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + { + operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2); + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. + arch::device_breakpoint(); +#endif + } + + // Adds a pointer offset in units of elements. 
+ CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) + { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + +private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Turing & Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer= 75 + && platform::is_same::value>::type> +{ + +public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) + { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) + { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) + { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) + { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB + = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) + { + if constexpr (hasZero(QuantOp)) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < 
MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize( + FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) + { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB + = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + + if constexpr (hasZero(QuantOp)) + { + plus plus_op; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + operand_frag_ptr[mma_n_iter] + = plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]); + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) + { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + +private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h new file mode 100644 index 00000000000..4acef2d180f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace cutlass_extensions +{ +// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape +// in the kernel layout details when doing weight only quantization. 
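// An illustrative helper (hypothetical, not part of the upstream header) for the
// naming convention noted above: tile-config names encode shapes as MxNxK, e.g.
// CtaShape128x128x64_WarpShape64x32x64 pairs a 128x128x64 threadblock tile with a
// 64x32x64 warp tile, and the trailing K (64 here) is the value that must match the
// threadblock K baked into the interleaved weight layout.
#include <cstdio>

struct TileShape { int m, n, k; };

int main() {
    TileShape cta  = {128, 128, 64};   // CtaShape128x128x64_...
    TileShape warp = {64, 32, 64};     // ..._WarpShape64x32x64
    printf("CTA %dx%dx%d, warp %dx%dx%d, %d x %d warps per CTA, threadblock K = %d\n",
           cta.m, cta.n, cta.k, warp.m, warp.n, warp.k,
           cta.m / warp.m, cta.n / warp.n, cta.k);
    return 0;
}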
+enum class CutlassTileConfig +{ + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // SiMT config + CtaShape128x128x8_WarpShape64x64x8, + + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=16 + CtaShape16x128x64_WarpShape16x32x64, + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, + + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x64x128_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, + + // Warp configs for M=128 + CtaShape128x64x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x64x64, + CtaShape128x128x64_WarpShape128x32x64, + CtaShape128x256x64_WarpShape64x64x64, + + // Warp configs for M=256 + CtaShape256x128x64_WarpShape64x64x64, + + // TensorCore config CTA_N = 64, CTA_K = 128 + CtaShape128x64x128_WarpShape64x32x128, + + // TensorCore config CTA_N = 256, CTA_K = 64 + CtaShape16x256x64_WarpShape16x64x64, + + // TensorCore config CTA_N = 256, CTA_K = 128 + CtaShape16x256x128_WarpShape16x64x128 + +}; + +enum class SplitKStyle +{ + NO_SPLIT_K, + SPLIT_K_SERIAL, + STREAM_K, // Sm80+ + // SPLIT_K_PARALLEL // Not supported yet +}; + +enum class CutlassTileConfigSM90 +{ + // Signals that we should run heuristics do choose a config + Undefined, + + // Signals that we should run heuristics do choose a config + ChooseWithHeuristic, + + // CTA configs for M=64 + CtaShape64x16x128B, + CtaShape64x32x128B, + CtaShape64x64x128B, + CtaShape64x128x128B, + CtaShape64x256x128B, + + // CTA configs for M=128 + CtaShape128x16x128B, + CtaShape128x32x128B, + CtaShape128x64x128B, + CtaShape128x128x128B, + CtaShape128x256x128B, + + // CTA configs for M=128 + CtaShape256x128x128B, +}; + +enum class MainloopScheduleType +{ + AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this + // defaults to the "legacy" main loop schedule. +}; + +enum class EpilogueScheduleType +{ + AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For + // architectures older than hopper, the epilogue is always performed by the same thread block as the main loop. 
+}; + +enum class ClusterShape +{ + ClusterShape_1x1x1, + ClusterShape_2x1x1, + ClusterShape_1x2x1, + ClusterShape_2x2x1, + ClusterShape_1x8x1, + ClusterShape_8x1x1 +}; + +struct CutlassGemmConfig +{ + enum CandidateConfigTypeParam : int + { + NONE = 0, + WEIGHT_ONLY = 1u << 0, + SIMT_ONLY = 1u << 1, + INT8_ONLY = 1u << 2, + HOPPER = 1u << 3, + GROUPED_GEMM = 1u << 4, + FP8_ONLY = 1u << 5, + }; + + CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; + + // config options for sm90 + CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic; + MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; + EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; + ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; + bool is_sm90 = false; + + CutlassGemmConfig() {} + + CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) + : tile_config(tile_config) + , split_k_style(split_k_style) + , split_k_factor(split_k_factor) + , stages(stages) + , is_sm90(false) + { + } + + CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule, + EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape) + : tile_config_sm90(tile_config_sm90) + , mainloop_schedule(mainloop_schedule) + , epilogue_schedule(epilogue_schedule) + , cluster_shape(cluster_shape) + , is_sm90(true) + { + } + + std::string toString() const + { + std::stringstream tactic; + tactic << "Cutlass GEMM Tactic"; + if (tile_config_sm90 != tensorrt_llm::cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic) + { + assert(is_sm90 && "Invalid cutlass GEMM config"); + tactic << "\n\tstyle=TMA" + << "\n\ttile shape ID: " << (int) tile_config_sm90 << "\n\tcluster shape ID: " << (int) cluster_shape + << "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule; + } + else if (tile_config != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) + { + assert(!is_sm90 && "Invalid cutlass GEMM config"); + tactic << "\n\tstyle=compatible" + << "\n\ttile shape ID: " << (int) tile_config << "\n\tstages: " << (int) stages + << "\n\tsplit k: " << (int) split_k_factor; + } + else + { + tactic << "\n\tundefined"; + } + tactic << "\n"; + return tactic.str(); + } +}; + +inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) +{ + // clang-format off + if (config.is_sm90) + { + out << "tile_config_sm90_enum: " << int(config.tile_config_sm90) + << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) + << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) + << ", cluster_shape_enum: " << int(config.cluster_shape); + } + else + { + out << "tile_config_enum: " << int(config.tile_config) + << ", split_k_style_enum: " << int(config.split_k_style) + << ", split_k_factor: " << config.split_k_factor + << ", stages: " << config.stages; + } + // clang-format on + return out; +} + +} // namespace cutlass_extensions +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h new file mode 100644 index 00000000000..44ba79680e6 --- /dev/null +++ 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h @@ -0,0 +1,447 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/array.h" +#include "cutlass/half.h" +#include "cutlass/numeric_types.h" + +namespace cutlass +{ + +// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low +// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally +// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. +// This converter will uninterleave the data and subtract the bias while converting to the result type. 
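// A scalar reference (not part of the upstream file) for the layout described above,
// using the uint8 case: the packed 32-bit word stores the even logical elements in the
// low half-word and the odd ones in the high half-word, each with a +128 bias, so its
// bytes (LSB first) are {e0, e2, e1, e3}. The device code un-interleaves with prmt and
// removes the bias with one sub.f16x2 per pair by forming halves equal to 1024 + byte
// and subtracting 1152; the plain-float arithmetic below mirrors that.
#include <cstdint>
#include <cstdio>

int main() {
    int8_t e[4] = {-100, 3, 0, 127};                        // logical signed elements
    uint32_t packed = (uint32_t(uint8_t(e[0] + 128)))       // bytes: e0, e2, e1, e3
                    | (uint32_t(uint8_t(e[2] + 128)) << 8)
                    | (uint32_t(uint8_t(e[1] + 128)) << 16)
                    | (uint32_t(uint8_t(e[3] + 128)) << 24);

    int src_byte_of[4] = {0, 2, 1, 3};                      // result index -> source byte
    for (int i = 0; i < 4; ++i) {
        int b = int((packed >> (8 * src_byte_of[i])) & 0xFFu);
        float via_magic = (1024.0f + float(b)) - 1152.0f;   // what prmt + sub.f16x2 yields
        printf("element %d: expected %4d, converted %5.0f\n", i, int(e[i]), via_magic);
    }
    return 0;
}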
+template +struct FastInterleavedAndBiasedNumericArrayConverter +{ +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* bf16_result_ptr = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + // Subtract out fp32_base + 128 to make the unsigned integer signed. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 4; ++ii) + { + fp32_intermediates[ii] -= 8388736.f; + } + + // Truncate the fp32 representation and pack up as bfloat16s. 
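// A scalar reference (not part of the upstream file) for the fp32-base trick used just
// above: OR-ing a biased byte b into the mantissa of 2^23 (bit pattern 0x4B000000)
// produces the float 8388608 + b, so subtracting 8388736 (= 2^23 + 128) leaves the
// signed value b - 128 exactly, and keeping only the top 16 bits of that float is the
// bf16 truncation that the packing loop below performs with __byte_perm.
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    for (int b : {0, 1, 127, 128, 200, 255}) {              // a few biased byte values
        uint32_t bits = 0x4B000000u | uint32_t(b);          // what __byte_perm assembles
        float f;
        std::memcpy(&f, &bits, sizeof(f));                  // f == 8388608 + b
        f -= 8388736.0f;                                    // f == b - 128, exactly
        uint32_t out_bits;
        std::memcpy(&out_bits, &f, sizeof(out_bits));
        printf("b = %3d -> value %4.0f, bf16 bits 0x%04X\n",
               b, f, unsigned(out_bits >> 16));
    }
    return 0;
}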
+ CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii) + { + bf16_result_ptr[ii] + = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + result.clear(); // Suppress compiler warning + arch::device_breakpoint(); +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t BOTTOM_MASK = 0x000f000f; + static constexpr uint32_t TOP_MASK = 0x00f000f0; + static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; + + // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing + // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. + // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and + // elt_67 to fp16 without having to shift them to the bottom bits before hand. + + // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue + // immediately before required. 
+ const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + static constexpr uint32_t NEG_72 = 0xd480d480; + + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. 
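// A small check (not part of the upstream file) of the magic-number identities the
// int4 converters rely on: a nibble n in [0, 15] stored with a +8 bias is recovered as
// (1024 + n) - 1032 for a low nibble and ((1024 + 16n) / 16) - 72 for a high nibble
// (the fp16 constants used above), and as (128 + n) - 136 for the bf16 path whose
// constants appear just below. All three reduce to n - 8 and are exact in half precision.
#include <cassert>
#include <cstdio>

int main() {
    for (int n = 0; n < 16; ++n) {
        float lo = (1024.0f + n) - 1032.0f;                        // sub.f16x2 with 0x6408
        float hi = (1024.0f + 16.0f * n) * (1.0f / 16.0f) - 72.0f; // fma with 1/16 and -72
        float bf = (128.0f + n) - 136.0f;                          // fma.rn.bf16x2 with -136
        assert(lo == float(n - 8) && hi == float(n - 8) && bf == float(n - 8));
    }
    printf("all 16 biased nibbles map back to n - 8\n");
    return 0;
}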
+ static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. + // No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) + { + i4s >>= sizeof_bits::value; + // (i4s & 0x000f000f) | 0x43004300 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) + { + // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h new file mode 100644 index 00000000000..5a0cd295708 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h @@ -0,0 +1,66 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines new layouts needed for MoE +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/pitch_linear_coord.h" + +namespace cutlass +{ +namespace layout +{ + +template +struct ColumnMajorTileInterleave +{ + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; +}; + +template +struct IsColumnMajorTileInterleave +{ + static constexpr bool value = false; +}; + +template +struct IsColumnMajorTileInterleave> +{ + static constexpr bool value = true; +}; + +} // namespace layout +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h new file mode 100644 index 00000000000..6095925e372 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
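For reference, the ColumnMajorTileInterleave/IsColumnMajorTileInterleave traits added above can be exercised at compile time as below. This is an illustrative snippet, not part of the patch; it assumes the interleave factors are plain int template parameters (as the kRowsPerTile/kColumnsInterleaved members suggest) and that the cutlass_extensions include directory is on the include path.

```cpp
// Compile-time smoke test for the tile-interleave traits defined above.
#include "cutlass_extensions/tile_interleaved_layout.h"
#include "cutlass/layout/matrix.h"

using Interleaved = cutlass::layout::ColumnMajorTileInterleave</*RowsPerTile=*/64, /*ColumnsInterleaved=*/2>;

static_assert(cutlass::layout::IsColumnMajorTileInterleave<Interleaved>::value,
    "interleaved layouts are detected");
static_assert(!cutlass::layout::IsColumnMajorTileInterleave<cutlass::layout::ColumnMajor>::value,
    "plain column-major layouts are not");

int main()
{
    return 0;
}
```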
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM + quantization. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace transform +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +template +class FineGrainedScaleZeroIterator; + +template +class FineGrainedScaleZeroIterator +{ +public: + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = 0; + static int const kAlignment = Alignment_; + + static int const kAccessesPerVector = 1; + + /// Row index of scales corresponding to the groupsize of 64 + int row_groupsize64_; + int group_size_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using AccessType = AlignedArray; + + using Fragment = cutlass::Array; + + // For compatibility with existing iterator interface + struct Params + { + LongIndex stride_ = 0; + + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_(layout.stride(0)) + { + inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; + } + }; + +private: + /// 
Internal pointer type permits fast address arithmetic + using BytePointer = char*; + +private: + // + // Data members + // + + /// Parameters object with precomputed internal state + Params const params_; + + /// Internal pointer to first access of tile + BytePointer pointer_scale_; + BytePointer pointer_zero_; + + bool is_valid_ = false; + +public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_DEVICE + FineGrainedScaleZeroIterator( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of scale tensor + Pointer pointer_scale, + ///< Pointer to start of zero tensor + Pointer pointer_zero, + ///< Extent of the scale and bias + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + ///< Group size + int group_size) + : params_(params) + , pointer_scale_(reinterpret_cast(const_cast(pointer_scale))) + , pointer_zero_(reinterpret_cast(const_cast(pointer_zero))) + { + row_groupsize64_ = threadblock_offset.row(); + group_size_ = group_size; + + const LongIndex tb_row_byte_offset + = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8; + const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits::value / 8; + pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); + + if (pointer_zero_ != nullptr) + { + pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); + } + + static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; + + int const thread_row = thread_id / THREADS_PER_ROW; + int const thread_col = thread_id % THREADS_PER_ROW; + + const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8; + const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; + pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); + if (pointer_zero_ != nullptr) + { + pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); + } + + // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on + // a given iteration. The same threads will be responsible for issues reads since the number of scales + // read in a given iteration is a constant. Therefore, we should never have to update is_valid_ + // outside of the constructor. 
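The tb_row_byte_offset expression in the constructor above packs three conversions into one line. The hypothetical helper below (added for illustration only, not code from the patch) spells them out, under the assumption that the incoming row offset is counted in 64-row groups, as the row_groupsize64_ member and its comment suggest: divide by (group_size / 64) to get the scale-row index, then scale by the row stride and element size to get bytes.

```cpp
// Hypothetical stand-alone helper mirroring the scale-row byte-offset arithmetic above.
#include <cassert>
#include <cstdint>
#include <cstdio>

int64_t scale_row_byte_offset(int64_t row_groupsize64, int group_size, int64_t stride_elems, int bits_per_elem)
{
    assert(group_size % 64 == 0);
    int64_t scale_row = row_groupsize64 / (group_size / 64); // one scale row per quantization group
    return scale_row * stride_elems * bits_per_elem / 8;     // scale-row index -> byte offset
}

int main()
{
    // Example: group_size 128, fp16 scales (16 bits), 4096 scales per row.
    // A tile starting at 64-row-group 2 (weight row 128) lands on scale row 1 -> 8192 bytes.
    std::printf("%lld bytes\n",
        (long long) scale_row_byte_offset(/*row_groupsize64=*/2, /*group_size=*/128,
                                          /*stride_elems=*/4096, /*bits_per_elem=*/16));
    return 0;
}
```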
+ int const global_row = threadblock_offset.row() + thread_row; + int const global_col = threadblock_offset.column() + thread_col * kAlignment; + + bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; + bool const col_in_bounds = global_col < extent.column(); + + is_valid_ = row_in_bounds && col_in_bounds; + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object + Pointer pointer_scale, ///< Pointer to start of scale tensor + Pointer pointer_zero, ///< Pointer to start of zero tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int group_size) + : FineGrainedScaleZeroIterator( + params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size) + { + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; + const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; + pointer_scale_ += row_byte_offset + col_byte_offset; + if (pointer_zero_ != nullptr) + { + pointer_zero_ += row_byte_offset + col_byte_offset; + } + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) + { + is_valid_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const + { + return is_valid_; + } + + /// Returns a scale pointer + CUTLASS_HOST_DEVICE + AccessType* get_scale() const + { + return reinterpret_cast(pointer_scale_); + } + + /// Returns a zero pointer + CUTLASS_HOST_DEVICE + AccessType* get_zero() const + { + return reinterpret_cast(pointer_zero_); + } +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp new file mode 100644 index 00000000000..b430380b014 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp @@ -0,0 +1,181 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/util/print.hpp" + +using namespace cute; + +/// Function object that applies an index to its argument +template +struct IndexedGather +{ + CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {}) + : indices_(indices) + { + } + + template + CUTE_HOST_DEVICE constexpr auto operator()(I i) const + { + return indices_[i]; + } + + CUTE_HOST_DEVICE friend void print(IndexedGather const& s) + { + cute::print("Indexed{"); + print(s.indices_); + print("}"); + } + + Iter indices_; +}; + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride +{ + CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, Stride const& stride) + : func_(func) + , stride_(stride) + { + } + + template + CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s) + { + return s.func_(i) * s.stride_; + } + + template + CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i) + { + return s.func_(i) * s.stride_; + } + + CUTE_HOST_DEVICE friend void print(CustomStride const& s) + { + cute::print("Custom{"); + print(s.func_); + cute::print(","); + print(s.stride_); + cute::print("}"); + } + + template + CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) + { + return CustomStride(s.func_, safe_div(s.stride_, div)); + } + + // Circumvent the requirement on make_layout that shape and stride are integral + template + CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride) + { + return Layout(shape, stride); + } + + Func func_; + Stride stride_; +}; + +template +CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func) +{ + // Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather stride + auto idx = find_if(stride, [](auto x) { return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + return make_layout( + repeat_like(stride, _1{}), replace(stride, CustomStride{static_cast(func), get(stride)})); +} + +/// Helper function to optionally create a gather tensor +template +CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func) +{ + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); + return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); +} + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) + { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s, d); }); + } + else if 
constexpr (is_scaled_basis::value) + { + if constexpr (Stride::mode() == I) + { + return make_layout(shape_div(shape, Int{}), shape_div(stride, Int{})); + } + else + { + return make_layout(shape, stride); + } + } + else + { + return upcast(shape, stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr auto upcast( + ComposedLayout, Offset, Layout> const& layout) +{ + // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +} // namespace cute diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h new file mode 100644 index 00000000000..64774428e9f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
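The IndexedGather/CustomStride pair defined above implements row gathering purely through the stride: multiplying a logical index by the custom stride first maps the index through an index array and only then applies the real stride, so make_gather_tensor never materializes the gathered rows. The dependency-free sketch below (illustrative only, not CUTE code) shows the same effect on a plain row-major buffer.

```cpp
// Plain-C++ illustration of gathering rows through an index-mapped stride.
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    // 4x3 row-major matrix stored densely.
    std::vector<float> data = { 0,  1,  2,
                               10, 11, 12,
                               20, 21, 22,
                               30, 31, 32};
    const std::size_t ld = 3;               // leading dimension (real row stride)
    std::vector<std::size_t> rows = {3, 1}; // gather rows 3 and 1

    // Logical element (i, j) of the gathered view -> data[rows[i] * ld + j],
    // i.e. "apply the index map, then the stride", exactly what CustomStride encodes.
    for (std::size_t i = 0; i < rows.size(); ++i)
        for (std::size_t j = 0; j < ld; ++j)
            std::printf("gathered(%zu,%zu) = %g\n", i, j, data[rows[i] * ld + j]);
    return 0;
}
```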
+*/ + +#pragma once + +namespace cutlass +{ + +enum class WeightOnlyQuantOp +{ + UNDEFINED, + PER_COLUMN_SCALE_ONLY, + FINEGRAINED_SCALE_ONLY, + FINEGRAINED_SCALE_AND_ZEROS +}; + +constexpr bool isFinegrained(WeightOnlyQuantOp op) +{ + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +} + +constexpr bool hasZero(WeightOnlyQuantOp op) +{ + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +} + +} // namespace cutlass diff --git a/sgl-kernel/THIRDPARTYNOTICES.txt b/sgl-kernel/THIRDPARTYNOTICES.txt index c930aa5dd3d..fcae14df3aa 100644 --- a/sgl-kernel/THIRDPARTYNOTICES.txt +++ b/sgl-kernel/THIRDPARTYNOTICES.txt @@ -223,3 +223,208 @@ BSD 3-Clause "New" License 3rdparty/cutlass include/flashinfer/attention/hopper/block_sparse_gather.cuh + +Notice for NVIDIA/TensorRT-LLM +------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
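The WeightOnlyQuantOp helpers introduced above are constexpr, so their intent can be pinned down with compile-time checks such as the following (illustrative only; assumes the cutlass_extensions include directory is on the include path).

```cpp
// Compile-time checks for the weight-only quantization mode helpers above.
#include "cutlass_extensions/weight_only_quant_op.h"

using cutlass::WeightOnlyQuantOp;

static_assert(cutlass::isFinegrained(WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS), "group-wise");
static_assert(cutlass::isFinegrained(WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY), "group-wise");
static_assert(!cutlass::isFinegrained(WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY), "per-column");
static_assert(cutlass::hasZero(WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS), "has zero points");
static_assert(!cutlass::hasZero(WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY), "scales only");

int main()
{
    return 0;
}
```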
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 90c3cbc1d3c..50299140312 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -39,6 +39,8 @@ def _get_version(): cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) flashinfer = root / "3rdparty" / "flashinfer" turbomind = root / "3rdparty" / "turbomind" +tensorrt_llm_parent = root / "3rdparty" +tensorrt_llm = root / "3rdparty" / "tensorrt_llm" include_dirs = [ cutlass.resolve() / "include", cutlass.resolve() / "tools" / "util" / "include", @@ -51,6 +53,8 @@ def _get_version(): "cublasLt", turbomind.resolve(), turbomind.resolve() / "src", + tensorrt_llm_parent.resolve(), + tensorrt_llm.resolve() / "cutlass_extensions" / "include", ] nvcc_flags = [ From e81d7f11dede2b9b3f82de00a433eccc3d47c25e Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 30 Jan 2025 23:49:14 +0800 Subject: [PATCH 11/12] add tensorrt_llm moe_gemm as 3rdparty (#3217) --- .../tensorrt_llm/common/cudaBf16Wrapper.h | 21 + .../tensorrt_llm/common/cudaDriverWrapper.cpp | 187 ---- .../tensorrt_llm/common/cudaDriverWrapper.h | 138 --- .../launchers/fused_moe_gemm_launcher_sm80.h | 25 + .../fused_moe_gemm_launcher_sm80.inl | 96 ++ .../launchers/moe_gemm_launcher_sm90.h | 37 + .../launchers/moe_gemm_launcher_sm90.inl | 348 ++++++++ .../moe_gemm/moe_gemm_hopper_input.cu | 131 +++ .../moe_gemm/moe_gemm_kernels.h | 230 +++++ .../moe_gemm/moe_gemm_kernels_bf16_bf16.cu | 24 + .../moe_gemm/moe_gemm_kernels_bf16_uint4.cu | 24 + .../moe_gemm/moe_gemm_kernels_bf16_uint8.cu | 24 + .../moe_gemm/moe_gemm_kernels_fp16_fp16.cu | 22 + .../moe_gemm/moe_gemm_kernels_fp16_uint4.cu | 22 + .../moe_gemm/moe_gemm_kernels_fp16_uint8.cu | 22 + .../moe_gemm/moe_gemm_kernels_fp32_fp32.cu | 22 + .../moe_gemm/moe_gemm_kernels_fp8_fp8.cu | 28 + .../moe_gemm/moe_gemm_kernels_template.h | 823 ++++++++++++++++++ .../moe_gemm/moe_gemm_kernels_template_sm90.h | 222 +++++ .../moe_gemm/moe_sm90_traits.h | 44 + 20 files changed, 2165 insertions(+), 325 deletions(-) create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h create mode 100644 
sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h new file mode 100644 index 00000000000..fb2a89af5cd --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef ENABLE_BF16 +#include +#endif diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp deleted file mode 100644 index 7eca46a1cab..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#define CUDA_LIB_NAME "cuda" - -#if defined(_WIN32) -#include -#define dllOpen(name) LoadLibrary("nv" name ".dll") -#define dllClose(handle) FreeLibrary(static_cast(handle)) -#define dllGetSym(handle, name) static_cast(GetProcAddress(static_cast(handle), name)) -#else // For non-Windows platforms -#include -#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY) -#define dllClose(handle) dlclose(handle) -#define dllGetSym(handle, name) dlsym(handle, name) -#endif // defined(_WIN32) - -#include "cudaDriverWrapper.h" -#include "tensorrt_llm/common/assert.h" -#include -#include - -namespace tensorrt_llm::common -{ - -std::shared_ptr CUDADriverWrapper::getInstance() -{ - static std::mutex mutex; - static std::weak_ptr instance; - std::shared_ptr result = instance.lock(); - if (result) - { - return result; - } - - std::lock_guard lock(mutex); - result = instance.lock(); - if (!result) - { - result = std::shared_ptr(new CUDADriverWrapper()); - instance = result; - } - return result; -} - -CUDADriverWrapper::CUDADriverWrapper() - : handle(dllOpen(CUDA_LIB_NAME)) -{ - - TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly."); - - auto load_sym = [](void* handle, char const* name) - { - void* ret = dllGetSym(handle, name); - return ret; - }; - - *reinterpret_cast(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName"); - *reinterpret_cast(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage"); - *reinterpret_cast(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute"); - *reinterpret_cast(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete"); - *reinterpret_cast(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload"); - *reinterpret_cast(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy"); - *reinterpret_cast(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData"); - *reinterpret_cast(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2"); - *reinterpret_cast(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction"); - *reinterpret_cast(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2"); - *reinterpret_cast(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2"); - *reinterpret_cast(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2"); - *reinterpret_cast(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel"); - *reinterpret_cast(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel"); - *reinterpret_cast(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled"); - *reinterpret_cast(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2"); -} - -CUDADriverWrapper::~CUDADriverWrapper() -{ - dllClose(handle); -} - -CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const -{ - return (*_cuGetErrorName)(error, pStr); -} - -CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const -{ - return (*_cuGetErrorMessage)(error, pStr); -} - -CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const -{ - return (*_cuFuncSetAttribute)(hfunc, attrib, value); -} - -CUresult 
CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const -{ - return (*_cuLinkComplete)(state, cubinOut, sizeOut); -} - -CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const -{ - return (*_cuModuleUnload)(hmod); -} - -CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const -{ - return (*_cuLinkDestroy)(state); -} - -CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const -{ - return (*_cuModuleLoadData)(module, image); -} - -CUresult CUDADriverWrapper::cuLinkCreate( - unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const -{ - return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut); -} - -CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const -{ - return (*_cuModuleGetFunction)(hfunc, hmod, name); -} - -CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const -{ - return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name); -} - -CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, - unsigned int numOptions, CUjit_option* options, void** optionValues) const -{ - return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues); -} - -CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, - char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const -{ - return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues); -} - -CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, - unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, - unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const -{ - return (*_cuLaunchCooperativeKernel)( - f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams); -} - -CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, - unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, - unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const -{ - return (*_cuLaunchKernel)( - f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra); -} - -CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, - cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const -{ - return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides, - boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill); -} - -CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const -{ - return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount); -} - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h deleted file mode 100644 index 
c4d470a85f0..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef CUDA_DRIVER_WRAPPER_H -#define CUDA_DRIVER_WRAPPER_H - -#include "tensorrt_llm/common/assert.h" -#include -#include -#include -#include - -namespace tensorrt_llm::common -{ - -class CUDADriverWrapper -{ -public: - static std::shared_ptr getInstance(); - - ~CUDADriverWrapper(); - CUDADriverWrapper(CUDADriverWrapper const&) = delete; - CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete; - CUDADriverWrapper(CUDADriverWrapper&&) = delete; - CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete; - - CUresult cuGetErrorName(CUresult error, char const** pStr) const; - - CUresult cuGetErrorMessage(CUresult error, char const** pStr) const; - - CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const; - - CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const; - - CUresult cuModuleUnload(CUmodule hmod) const; - - CUresult cuLinkDestroy(CUlinkState state) const; - - CUresult cuModuleLoadData(CUmodule* module, void const* image) const; - - CUresult cuLinkCreate( - unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const; - - CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const; - - CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const; - - CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions, - CUjit_option* options, void** optionValues) const; - - CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name, - unsigned int numOptions, CUjit_option* options, void** optionValues) const; - - CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, - unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, - unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const; - - CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, - unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, - CUstream hStream, void** kernelParams, void** extra) const; - - CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, - void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim, - cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, - CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const; - - CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const; - -private: - void* handle; - 
CUDADriverWrapper(); - - CUresult (*_cuGetErrorName)(CUresult, char const**); - CUresult (*_cuGetErrorMessage)(CUresult, char const**); - CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int); - CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*); - CUresult (*_cuModuleUnload)(CUmodule); - CUresult (*_cuLinkDestroy)(CUlinkState); - CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*); - CUresult (*_cuModuleLoadData)(CUmodule*, void const*); - CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*); - CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*); - CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**); - CUresult (*_cuLinkAddData)( - CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**); - CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, - unsigned int, unsigned int, unsigned int, CUstream, void**); - CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, - unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, - CUstream hStream, void** kernelParams, void** extra); - CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, - cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); - CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount); -}; - -template -void checkDriver( - T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line) -{ - if (result) - { - char const* errorName = nullptr; - char const* errorMsg = nullptr; - wrap.cuGetErrorName(result, &errorName); - wrap.cuGetErrorMessage(result, &errorMsg); - throw TllmException( - file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg)); - } -} - -} // namespace tensorrt_llm::common - -/* - * Macros compliant with TensorRT coding conventions - */ -#define TLLM_CU_CHECK(stat) \ - do \ - { \ - tensorrt_llm::common::checkDriver( \ - (stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \ - } while (0) - -#endif // CUDA_DRIVER_WRAPPER_H diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h new file mode 100644 index 00000000000..f4eed277c18 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +namespace tensorrt_llm::kernels::cutlass_kernels +{ +template +void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, + ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert, + int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, + int* kernel_occupancy); +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl new file mode 100644 index 00000000000..126e761ec93 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" + +#include "cute/tensor.hpp" +#include "cutlass/cutlass.h" + +#include +#include +#include + +namespace tensorrt_llm::kernels::cutlass_kernels +{ +template +void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, + ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert, + int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, + int* kernel_occupancy) +{ + constexpr auto activation_type = fused_moe::EpilogueRouting(true); + using GemmType = fused_moe::Fused_Moe_Kernel_sm80; + + // make sure GPU has enough resources.. + if (kernel_occupancy != nullptr) + { + constexpr int smem_size = GemmType::kSmemSize; + + if (smem_size > (48 << 10)) + { + cudaFuncAttributes attr{}; + int device = 0; + int max_smem_per_block = 0; + tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); + tensorrt_llm::common::check_cuda_error( + cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); + tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe::run_global)); + if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) + { + // This should mean that + // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + // smem_size) wouldn't work. 
In that case, we return an occupancy of 0. This will cause the + // heuristic to ignore this configuration. + *kernel_occupancy = 0; + return; + } + } + + int max_active_blocks = -1; + tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, fused_moe::run_global, GemmType::kThreadCount, smem_size)); + *kernel_occupancy = max_active_blocks; + return; + } + int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks()); + int const threadblock_count = multi_processor_count * occupancy; + TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel"); + using Arguments = typename GemmType::Arguments; + Arguments args{{const_cast(A), const_cast(B), const_cast(biases), + reinterpret_cast(C), total_tokens_including_expert, static_cast(gemm_n), + static_cast(gemm_k), num_experts, bias_is_broadcast}, + num_experts, threadblock_count}; + auto params = GemmType::to_underlying_arguments(args); + if (GemmType::kSmemSize >= (48 << 10)) + { + cudaError_t result = cudaFuncSetAttribute( + fused_moe::run_global, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize); + TLLM_CHECK_WITH_INFO(result == cudaSuccess, + "Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + " for fused moe kernel"); + } + dim3 grid(params.threadblock_count, 1, 1); + dim3 block(GemmType::kThreadCount); + fused_moe::run_global<<>>(params); + auto result = cudaGetLastError(); + TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result)); +} +} // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h new file mode 100644 index 00000000000..91527fadb67 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h @@ -0,0 +1,37 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
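The sm80 launcher above follows the standard CUDA recipe for kernels whose dynamic shared memory exceeds the 48 KiB default: query the opt-in limit, raise cudaFuncAttributeMaxDynamicSharedMemorySize on the kernel, then size the grid from the occupancy query. A minimal CUDA C++ sketch of that recipe follows (dummy_kernel is a hypothetical placeholder, not code from the patch; compile with nvcc).

```cpp
// Minimal illustration of the >48 KiB dynamic shared memory opt-in and occupancy query.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_kernel(float* out)
{
    extern __shared__ float smem[];
    smem[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    if (threadIdx.x == 0)
        out[blockIdx.x] = smem[0];
}

int main()
{
    int device = 0;
    cudaGetDevice(&device);

    int max_smem_optin = 0;
    cudaDeviceGetAttribute(&max_smem_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

    constexpr int kThreads = 128;
    int smem_size = 64 * 1024; // 64 KiB of dynamic shared memory, above the 48 KiB default
    if (smem_size > max_smem_optin)
    {
        std::printf("requested %d bytes but the device only allows %d\n", smem_size, max_smem_optin);
        return 0;
    }

    // Opt the kernel into the larger dynamic shared memory carve-out.
    cudaFuncSetAttribute(dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

    // Same occupancy query the launcher above uses before sizing its grid.
    int max_active_blocks = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, dummy_kernel, kThreads, smem_size);
    std::printf("max active blocks per SM: %d\n", max_active_blocks);

    float* out = nullptr;
    cudaMalloc(&out, sizeof(float));
    dummy_kernel<<<1, kThreads, smem_size>>>(out);
    cudaDeviceSynchronize();
    cudaFree(out);
    return 0;
}
```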
+ */ +#pragma once + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" +#include + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +// Keep in sync with the signature generated by generate_kernels.py +template +void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts, + int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl new file mode 100644 index 00000000000..cca60a9816f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl @@ -0,0 +1,348 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/tensor_ref.h" + +#include "cutlass_extensions/compute_occupancy.h" +#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp" +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" +#include "cutlass_extensions/gemm/threadblock/default_mma.h" + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" + +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ +using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion; + +// Hopper helper class for defining all the cutlass helper types +template +struct 
HopperGroupedGemmInfo +{ + using Arch = cutlass::arch::Sm90; + + // TODO Update once mixed input support is added + static_assert(cutlass::platform::is_same::value, + "CUTLASS does not currently have specialised SM90 support for quantized operations"); + +#ifdef ENABLE_FP8 + constexpr static bool IsFP8 + = cutlass::platform::is_same::value || cutlass::platform::is_same::value; +#else + constexpr static bool IsFP8 = false; +#endif + +#ifdef ENABLE_BF16 + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value || IsFP8, + "Specialized for bfloat16, half, float, fp8"); +#else + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || IsFP8, + "Specialized for half, float, fp8"); +#endif + + static_assert(cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + "Unexpected quantization type"); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using ElementType = typename TllmToCutlassTypeAdapter::type; + + using CutlassWeightTypeMaybeUint4 = typename TllmToCutlassTypeAdapter::type; + // For legacy reasons we convert unsigned 8-bit to signed + using CutlassWeightTypeMaybeUint8 + = std::conditional_t, cutlass::int4b_t, + CutlassWeightTypeMaybeUint4>; + using CutlassWeightType + = std::conditional_t, int8_t, CutlassWeightTypeMaybeUint8>; + + using ElementA = ElementType; + using ElementB = CutlassWeightType; + + using ElementD = typename TllmToCutlassTypeAdapter>::type; + using ElementFinalOutput = typename TllmToCutlassTypeAdapter::type; + + // using ElementC = std::conditional_t; + // using ElementCNoVoid = std::conditional_t; + using ElementC = void; + using ElementCNoVoid = ElementD; + + using ElementAccumulator = float; + + using ElementBias = ElementFinalOutput; + using ElementRouterScales = float; + + // A matrix configuration - this is transposed and swapped with B + using LayoutA = HopperGroupedGemmInput::LayoutA; + constexpr static int AlignmentA + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units + // of elements (up to 16 bytes) + + // B matrix configuration - this is transposed and swapped with A + using LayoutB = HopperGroupedGemmInput::LayoutB; // Layout type for B matrix operand + constexpr static int AlignmentB + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units + // of elements (up to 16 bytes) + + // C matrix configuration + using LayoutC = HopperGroupedGemmInput::LayoutC; // Layout type for C matrix operand + using StrideC = HopperGroupedGemmInput::StrideC; + // Note we use ElementType here deliberately, so we don't break when BIAS is disabled + constexpr static int AlignmentC + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units + // of elements (up to 16 bytes) + + // D matrix configuration + using LayoutD = HopperGroupedGemmInput::DefaultEpilogue::LayoutD; + using StrideD = HopperGroupedGemmInput::DefaultEpilogue::StrideD; + constexpr static int AlignmentD + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix + // in units of elements (up to 16 bytes) + + static_assert(cutlass::platform::is_same::value, + "Hopper Grouped GEMM specialisation doesn't support fused activation"); + + using EpilogueOp + = 
cutlass::epilogue::fusion::LinearCombination; + + // TODO Add mode for fused activation once CUTLASS adds support + // using EpilogueSchedule = cutlass::platform::conditional_t< + // cutlass::platform::is_same::value, + // cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, + // cutlass::epilogue::?????????????????? /// <<<<<< what supports activations + // >; + using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; + + // Epilogue For Default Finalize + using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder< // + Arch, cutlass::arch::OpClassTensorOp, // + TileShape, ClusterShape, // + cutlass::epilogue::collective::EpilogueTileAuto, // + ElementAccumulator, ElementAccumulator, // + ElementC, LayoutC*, AlignmentC, // + ElementD, LayoutD*, AlignmentD, // + EpilogueSchedule>::CollectiveOp; + + // Epilogue For Fused Finalize + using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< // + TileShape, // + ElementCNoVoid, StrideC*, // + ElementFinalOutput, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, // + ElementAccumulator, // + ElementAccumulator, // + ElementBias, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, // + ElementRouterScales, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales // + >::CollectiveOp; + + using CollectiveEpilogue + = std::conditional_t; + + using StageCountAutoCarveout = cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>; + + using KernelSchedule + = std::conditional_t; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< // + Arch, cutlass::arch::OpClassTensorOp, // + CutlassWeightType, LayoutB*, AlignmentB, // A & B swapped here + ElementType, LayoutA*, AlignmentA, // + ElementAccumulator, // + TileShape, ClusterShape, // + StageCountAutoCarveout, KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal; + + using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter; +}; + +// Hopper specialised version +template +void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts, + int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size) +{ +#ifdef COMPILE_HOPPER_TMA_GEMMS + using namespace cute; + if constexpr (!should_filter_sm90_gemm_problem_shape_v) + { + using GemmInfo + = HopperGroupedGemmInfo; + + using ElementAccumulator = typename GemmInfo::ElementAccumulator; + using ElementA = typename GemmInfo::ElementA; + using ElementB = typename GemmInfo::ElementB; + using ElementC = typename GemmInfo::ElementC; + using ElementCNoVoid = typename GemmInfo::ElementCNoVoid; + using ElementD = typename GemmInfo::ElementD; + + using CollectiveMainloop = typename GemmInfo::CollectiveMainloop; + using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue; + using GemmKernel = typename GemmInfo::GemmKernel; + using GemmGrouped = typename GemmInfo::GemmGrouped; + + if (kernel_occupancy != nullptr) + { + *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = multi_processor_count; + + GemmGrouped gemm; + + if (workspace_size != nullptr) + { + // Make a mock problem shape with just the minimal information actually required to get the workspace size + // This makes some assumptions about CUTLASS's 
implementation which is suboptimal. We have a check later to + // catch future cutlass updates causing silent breakages, but that is not fool proof. + // The alternative is to wait until we have data and then dynamically allocate the workspace + typename HopperGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, nullptr}; + + typename GemmGrouped::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, shape_info, {}, {}, hw_info}; + *workspace_size = gemm.get_workspace_size(args); + return; + } + + using MainloopArguments = typename CollectiveMainloop::Arguments; + TLLM_CHECK(hopper_input.stride_a); + TLLM_CHECK(hopper_input.stride_b); + TLLM_CHECK(hopper_input.ptr_a); + TLLM_CHECK(hopper_input.ptr_b); + + MainloopArguments const mainloop_params = {reinterpret_cast(hopper_input.ptr_b), + hopper_input.stride_b, reinterpret_cast(hopper_input.ptr_a), hopper_input.stride_a}; + + typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{ + ElementAccumulator(1.f), hopper_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)}; + epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + // TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS + auto make_epi_args = [&]() + { + if constexpr (FUSION == EpilogueFusion::NONE) + { + auto epi_params = hopper_input.default_epilogue; + return EpilogueArguments{epilogue_scalars, reinterpret_cast(hopper_input.ptr_c), + hopper_input.stride_c, reinterpret_cast(epi_params.ptr_d), epi_params.stride_d}; + } + else if constexpr (FUSION == EpilogueFusion::FINALIZE) + { + // Parameters for fused finalize + auto epi_params = hopper_input.fused_finalize_epilogue; + return EpilogueArguments{ + epilogue_scalars, // Parameters to underlying epilogue + reinterpret_cast(hopper_input.ptr_c), hopper_input.stride_c, // C params + reinterpret_cast(epi_params.ptr_final_output), + epi_params.stride_final_output, // D (output) params + reinterpret_cast(epi_params.ptr_bias), + epi_params.stride_bias, // Bias params + epi_params.ptr_router_scales, epi_params.stride_router_scales, // Router scales + epi_params.ptr_expert_first_token_offset, // Offset of this expert's token in the router scales + epi_params.ptr_source_token_index, // Index of the source token to sum into + epi_params.num_rows_in_final_output // Number of tokens in the output buffer + }; + } + else + { + static_assert( + sizeof(EpilogueArguments) == 0, "Unimplemented fusion provided to SM90+ MoE gemm launcher"); + } + }; + EpilogueArguments const epilogue_params = make_epi_args(); + + typename GemmKernel::TileScheduler::Arguments scheduler_args{ + 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; + + typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info, + mainloop_params, epilogue_params, hw_info, scheduler_args}; + + size_t calculated_ws_size = gemm.get_workspace_size(args); + TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size, + "Workspace is size %zu but only %zu were allocated", calculated_ws_size, hopper_input.gemm_workspace_size); + + auto can_implement = gemm.can_implement(args); + TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, + "Grouped GEMM kernel will fail for params. 
Error: " + std::string(cutlassGetStatusString(can_implement))); + + auto init_status = gemm.initialize(args, hopper_input.gemm_workspace); + TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, + "Failed to initialize cutlass SM90 grouped gemm. Error: " + + std::string(cutlassGetStatusString(init_status))); + + auto run_status = gemm.run(stream); + TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, + "Failed to run cutlass SM90 grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status))); + sync_check_cuda_error(); + } + else + { + TLLM_THROW("Configuration was disabled by FAST_BUILD"); + } + +#else // COMPILE_HOPPER_TMA_GEMMS + TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py."); +#endif // COMPILE_HOPPER_TMA_GEMMS +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu new file mode 100644 index 00000000000..9862460dd6a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu @@ -0,0 +1,131 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/conv/convolution.h" +// Order matters here, packed_stride.hpp is missing cute and convolution includes +#include "cutlass/util/packed_stride.hpp" + +#include "tensorrt_llm/common/logger.h" + +namespace tensorrt_llm +{ +std::array HopperGroupedGemmInput::workspaceBuffers(int num_experts) +{ + size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts; + size_t stride_a_size = sizeof(StrideA) * num_experts; + size_t stride_b_size = sizeof(StrideB) * num_experts; + size_t stride_c_size = sizeof(StrideC) * num_experts; + size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts; + + size_t ptr_buf_size = sizeof(void*) * num_experts; + size_t scale_buf_size = sizeof(float*) * num_experts; + + return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size, + ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size}; +} + +size_t HopperGroupedGemmInput::workspaceSize(int num_experts) +{ + auto buffers = workspaceBuffers(num_experts); + return tensorrt_llm::common::calculateTotalWorkspaceSize(buffers.data(), buffers.size()); +} + +void HopperGroupedGemmInput::configureWorkspace( + int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size) +{ + auto buffers = workspaceBuffers(num_experts); + std::array pointers{}; + TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers"); + for (int i = 0; i < buffers.size(); i++) + { + pointers[i] = start_ptr; + start_ptr = tensorrt_llm::common::nextWorkspacePtr(start_ptr, buffers[i]); + } + + shape_info.num_groups = num_experts; + shape_info.problem_shapes = reinterpret_cast(pointers[0]); + shape_info.host_problem_shapes = nullptr; + stride_a = reinterpret_cast(pointers[1]); + stride_b = reinterpret_cast(pointers[2]); + stride_c = reinterpret_cast(pointers[3]); + default_epilogue.stride_d = reinterpret_cast(pointers[4]); + + ptr_a = reinterpret_cast(pointers[5]); + ptr_b = reinterpret_cast(pointers[6]); + ptr_c = reinterpret_cast(pointers[7]); + default_epilogue.ptr_d = reinterpret_cast(pointers[8]); + + alpha_scale_ptr_array = reinterpret_cast(pointers[9]); + + this->gemm_workspace = reinterpret_cast(gemm_workspace); + this->gemm_workspace_size = gemm_workspace_size; +} + +void HopperGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales, + int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, + int num_output_tokens) +{ + fused_finalize_epilogue.ptr_final_output = final_output; + fused_finalize_epilogue.ptr_router_scales = router_scales; + fused_finalize_epilogue.ptr_bias = bias; + fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset; + fused_finalize_epilogue.ptr_source_token_index = source_token_index; + + fused_finalize_epilogue.stride_final_output + = cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{}, + transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1))); + fused_finalize_epilogue.stride_bias + = transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size)); + fused_finalize_epilogue.stride_router_scales = {}; + + fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens; +} + +std::string HopperGroupedGemmInput::toString() const +{ + 
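+    // Render a human-readable summary of the pointers and strides configured above; the
+    // fused-finalize fields are reported only when that epilogue fusion is selected.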
std::stringstream ss; + ss << "Hopper Input Information: " << (isValid() ? "valid" : "null") << "\n"; + if (isValid()) + { + ss << "Ptr A: " << ptr_a << ", Ptr B: " << ptr_b << ", Ptr C: " << ptr_c << "\n"; + ss << "Epilogue Fusion: " << (int) fusion; + if (fusion == HopperGroupedGemmInput::EpilogueFusion::FINALIZE) + { + ss << ",\nFinal Output: " << fused_finalize_epilogue.ptr_final_output; + ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; + ss << ",\nBias: " << fused_finalize_epilogue.ptr_bias; + ss << " with Stride: " << fused_finalize_epilogue.stride_bias; + ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales; + ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; + ss << ",\nExpert Offset: " << fused_finalize_epilogue.ptr_expert_first_token_offset; + ss << ", Source Map: " << fused_finalize_epilogue.ptr_source_token_index; + } + else + { + ss << ", Ptr D: " << default_epilogue.ptr_d; + } + ss << '\n'; + ss << "Alpha scale ptr: " << alpha_scale_ptr_array << "\n"; + } + return ss.str(); +} +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h new file mode 100644 index 00000000000..0616c063654 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h @@ -0,0 +1,230 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include "tensorrt_llm/common/workspace.h" +#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h" +#include +#include +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/layout/layout.h" + +namespace tensorrt_llm +{ +template +constexpr auto transpose_stride(T const& t) +{ + return cute::prepend(cute::prepend(cute::take<2, cute::rank_v>(t), cute::get<0>(t)), cute::get<1>(t)); +} + +struct HopperGroupedGemmInput +{ + template + using TransposeStride = decltype(transpose_stride(T{})); + template + using TransposeLayoutTag = std::conditional_t, + cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; + + static_assert(std::is_same_v>); + static_assert(std::is_same_v>); + + // Layout for A and B is transposed and then swapped in the implementation + // This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM + using LayoutA = TransposeLayoutTag; // Layout type for A matrix operand + using LayoutB = TransposeLayoutTag; // Layout type for B matrix operand + using LayoutC = TransposeLayoutTag; // Layout type for C matrix operand + + using StrideA + = std::remove_pointer_t>; // Use B because they will be swapped + using StrideB + = std::remove_pointer_t>; // Use A because they will be swapped + using StrideC = std::remove_pointer_t>; + + template + constexpr static bool IsFP8_v = std::is_same_v || std::is_same_v; + + // Currently this should always just be T + template + using OutputTypeAdaptor_t = std::conditional_t, nv_bfloat16, T>; + + using ProblemShape = cutlass::gemm::GroupProblemShape>; + + ProblemShape shape_info{}; + StrideA* stride_a = nullptr; + StrideB* stride_b = nullptr; + + void const** ptr_a = nullptr; + void const** ptr_b = nullptr; + + // C is currently the same in both epilogues + StrideC* stride_c = nullptr; + void const** ptr_c = nullptr; + + struct DefaultEpilogue + { + using LayoutD = TransposeLayoutTag; // Layout type for D matrix operand + using StrideD = std::remove_pointer_t>; + + StrideD* stride_d = nullptr; + void** ptr_d = nullptr; + }; + + struct FusedFinalizeEpilogue + { + using StrideFinalOutput = DefaultEpilogue::StrideD; + using StrideBias = TransposeStride>; + using StrideRouterScales = TransposeStride>; + + void* ptr_final_output = nullptr; + StrideFinalOutput stride_final_output{}; + + void const* ptr_bias = nullptr; + StrideBias stride_bias{}; + + float const* ptr_router_scales = nullptr; + StrideRouterScales stride_router_scales{}; + + int64_t const* ptr_expert_first_token_offset = nullptr; + int const* ptr_source_token_index = nullptr; + + size_t num_rows_in_final_output = 0; + }; + + DefaultEpilogue default_epilogue; + FusedFinalizeEpilogue fused_finalize_epilogue; + + enum class EpilogueFusion + { + NONE, + ACTIVATION, + GATED_ACTIVATION, + FINALIZE + }; + EpilogueFusion fusion = EpilogueFusion::NONE; + + float const** alpha_scale_ptr_array = nullptr; + + uint8_t* gemm_workspace = nullptr; + size_t gemm_workspace_size = 0; + + static std::array workspaceBuffers(int num_experts); + + static size_t workspaceSize(int num_experts); + + void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size); + + bool isValid() const + { + return stride_a != nullptr && ptr_a != nullptr; + } + + void setFinalizeFusionParams(void* final_output, float const* router_scales, + int64_t const* 
expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, + int num_output_tokens); + + std::string toString() const; +}; + +// Note update moe.py to match +enum class ActivationType +{ + Gelu = 0, + Relu, + Silu, + Swiglu, + Geglu, + Identity, + InvalidType +}; + +constexpr bool isGatedActivation(ActivationType activation_type) +{ + return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu; +} + +template +class MoeGemmRunner +{ +public: + MoeGemmRunner(); + +#if defined(ENABLE_FP8) + static constexpr bool use_fp8 = std::is_same_v || std::is_same_v; +#else + static constexpr bool use_fp8 = false; +#endif + + void moeGemmBiasAct(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, + ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, + HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + ActivationType activation_type, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + cutlass_extensions::CutlassGemmConfig chosen_conf); + + void moeGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, void* C, + int64_t const* total_tokens_including_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array, + cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf); + + std::vector getConfigs() const; + static std::vector getConfigs(int sm); + static std::vector getHopperConfigs(int sm); + static std::vector getAmpereConfigs(int sm); + + [[nodiscard]] bool isHopperSpecialised(cutlass_extensions::CutlassGemmConfig gemm_config) const; + [[nodiscard]] bool supportsHopperSpecialisation() const; + [[nodiscard]] bool isFusedGatedActivation( + cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const; + [[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const; + + size_t getMaxWorkspaceSize(int num_experts) const; + + [[nodiscard]] int getSM() const; + +private: + template + void dispatchToArch(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, + ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, + HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array, + cudaStream_t stream, int* occupancy = nullptr); + + template + void runGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, + bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, + HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + cutlass_extensions::CutlassGemmConfig chosen_conf); + +private: + int sm_{}; + int multi_processor_count_{}; + mutable int num_experts_ = 0; + mutable size_t gemm_workspace_size_ = 0; + size_t calcMaxWorkspaceSize(int num_experts) const; +}; + +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu 
b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu new file mode 100644 index 00000000000..3aa96502d39 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>; +#endif +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu new file mode 100644 index 00000000000..fbb5270455e --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16>; +#endif +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu new file mode 100644 index 00000000000..78f1a93a6a8 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, uint8_t, __nv_bfloat16>; +#endif +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu new file mode 100644 index 00000000000..69c4b6a15a8 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu new file mode 100644 index 00000000000..4ffa5485f0f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu new file mode 100644 index 00000000000..424b817b876 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu new file mode 100644 index 00000000000..f317023565c --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +template class MoeGemmRunner; +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu new file mode 100644 index 00000000000..c6b8fe78724 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu @@ -0,0 +1,28 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" + +namespace tensorrt_llm +{ +#ifdef ENABLE_FP8 +template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>; +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>; +#endif +// template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>; +#endif +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h new file mode 100644 index 00000000000..2a337e6ca4e --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h @@ -0,0 +1,823 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Ignore CUTLASS warnings about type punning +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/tensor_ref.h" + +#include "cutlass_extensions/compute_occupancy.h" +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" +#include "cutlass_extensions/gemm/threadblock/default_mma.h" + +#ifdef __GNUC__ // Restore GCC-specific diagnostics +#pragma GCC diagnostic pop +#endif + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" + +#include "moe_gemm_kernels_template_sm90.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" +#include + +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace kernels::cutlass_kernels +{ + +// ============================= Variable batched Gemm things =========================== +template +void 
genericMoeGemmKernelLauncher(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* kernel_occupancy = nullptr) +{ +#if defined(ENABLE_FP8) + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for fp8, bfloat16, half, float"); +#elif defined(ENABLE_BF16) + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + "Specialized for bfloat16, half, float"); +#else + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for half, float"); +#endif + + static_assert(cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || cutlass::platform::is_same::value, + ""); + + static_assert(!cutlass::platform::is_same::value, + "Sm90 architecture should use specialised kernels"); + + // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. + using ElementType = typename TllmToCutlassTypeAdapter::type; + using CutlassGemmOutputType = typename TllmToCutlassTypeAdapter::type; + using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; + if (!use_fused_moe) + { + // We need separate config for each architecture since we will target different tensorcore instructions. For + // float, we do not target TCs. + using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; + using ElementAccumulator = typename MixedGemmArchTraits::AccType; + + using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue::Op; + + typename EpilogueOp::Params epilogue_op( + ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); + +#if defined(ENABLE_FP8) + if constexpr ((std::is_same_v + || std::is_same_v) &&std::is_same_v) + { + TLLM_CHECK_WITH_INFO(weight_scales == nullptr && biases == nullptr && alpha_scale_ptr_array, + "weight_scales and biases should be nullptr and alpha_scale_ptr_array shouldn't be nullptr for FP8 " + "Ada"); + epilogue_op.alpha_ptr_array = alpha_scale_ptr_array; + } +#endif + + // Finally, set up the kernel. 
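+        // Sketch of the composition below (template arguments elided): DefaultGemmGrouped builds the
+        // plain CUTLASS grouped-GEMM kernel for this arch/tile/warp shape, MoeFCGemm re-wraps it so the
+        // per-expert problem sizes are derived at run time from total_tokens_including_expert, and
+        // device::GemmGrouped is the host-side adapter used to check, initialize and run it.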
+ using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped::GemmKernel; + + using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + if (kernel_occupancy != nullptr) + { + *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); + return; + } + int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); + TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); + int const threadblock_count = multi_processor_count * occupancy; + + int const group_size = gemm_k; + typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op, + reinterpret_cast(A), reinterpret_cast(B), + reinterpret_cast(weight_scales), + reinterpret_cast(biases), bias_is_broadcast, + reinterpret_cast(C), total_tokens_including_expert, gemm_n, gemm_k); + + GemmGrouped gemm; + + auto can_implement = gemm.can_implement(args); + TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, + "MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement))); + + auto init_status = gemm.initialize(args); + TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, + "Failed to initialize cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(init_status))); + + auto run_status = gemm.run(stream); + TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, + "Failed to run cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status))); + } + else if constexpr (sizeof(ElementType) == 2 && sizeof(CutlassWeightType) == 2 + && (std::is_same_v + || std::is_same_v) ) // use fused moe gemm + // kernel.. (only support + // fp16 or bf16) + { + sm80_generic_fused_moe_gemm_kernelLauncher(reinterpret_cast(A), + reinterpret_cast(B), reinterpret_cast(biases), + bias_is_broadcast, reinterpret_cast(C), total_tokens_including_expert, num_rows, gemm_n, + gemm_k, num_experts, multi_processor_count, stream, kernel_occupancy); + } +} + +} // namespace kernels::cutlass_kernels + +template +static void dispatch(T const* A, WeightType const* B, GemmOutputType const* weight_scales, GemmOutputType const* biases, + bool bias_is_broadcast, GemmOutputType* C, int64_t const* total_tokens_including_expert, int64_t num_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, + int multi_processor_count, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + int* occupancy = nullptr) +{ + + static_assert(!std::is_same_v, "Use TMA specialised functions for arch SM90"); +#if defined(ENABLE_FP8) + constexpr bool isFp8 = std::is_same_v || std::is_same_v; +#else + constexpr bool isFp8 = false; +#endif + + if constexpr ((Stages == 2 || Arch::kMinComputeCapability >= 80) + && (!isFp8 || std::is_same_v) ) + { + kernels::cutlass_kernels::genericMoeGemmKernelLauncher(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + else + { + TLLM_THROW( + "Cutlass gemm. 
Not instantiated for arch %d with stages set to %d", Arch::kMinComputeCapability, Stages); + } +} + +template +void dispatchGemmConfig(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.stages) + { + case 2: + dispatch(A, B, weight_scales, + biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case 3: + dispatch(A, B, weight_scales, + biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case 4: + dispatch(A, B, weight_scales, + biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break; + } +} + +// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. +// This overload is only enabled when T == WeightType. +template ::value +#if defined(ENABLE_FP8) + && !std::is_same::value && !std::is_same::value +#endif + && std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) + { + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + break; + case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) + { + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + break; + case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, 
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break; + } +} + +// Tensorop GEMM overload +// Overload for quantize MoE GEMMs. We disable some warp configs here since they will not be used and we can improve +// compile time +template ::value && !std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: + TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) + { + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + break; + case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: + TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); + if constexpr (arch::kMinComputeCapability >= 75) + { + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, + multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + break; + case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, 
weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Config is invalid for mixed type tensorop GEMM."); break; + } +} + +// This overload will handle tensorop gemms. +// This overload is only enabled when T == WeightType and T == __nv_fp8_e4m3 or __nv_fp8_e5m2 +#if defined(ENABLE_FP8) +template ::value || std::is_same::value) + && std::is_same::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: + dispatchGemmConfig, + cutlass::gemm::GemmShape<16, 64, 128>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, 
alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break; + } +} +#endif + +// This overload will handle simt gemms. It is disabled via SFINAE for tensorop. +template ::value>::type* = nullptr> +void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, + GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, + int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) +{ + switch (gemm_config.tile_config) + { + case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + dispatchGemmConfig, + cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, + use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + break; + case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; + case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Unsupported config for float MoE gemm."); break; + } +} + +template +std::vector +MoeGemmRunner::getConfigs() const +{ + return getConfigs(sm_); +} + +template +std::vector MoeGemmRunner::getConfigs( + int sm) +{ + std::vector candidate_configs = getHopperConfigs(sm); + std::vector ampere_configs = getAmpereConfigs(sm); + std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs)); + + return candidate_configs; +} + +template +std::vector +MoeGemmRunner::getAmpereConfigs(int sm) +{ + using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; + static constexpr auto weight_only_flag + = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; + static constexpr auto simt_only_flag + = std::is_same::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE; + static constexpr auto fp8_only_flag = use_fp8 ? 
CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE; + int const max_split_k = 1; + int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM; + int const enable_hopper = CutlassGemmConfig::NONE; + + auto config_type_param = static_cast( + weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); + + if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation()) + { + return {}; + } + + std::vector ampere_configs + = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); + return ampere_configs; +} + +template +std::vector +MoeGemmRunner::getHopperConfigs(int sm) +{ + using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; + static constexpr auto weight_only_flag + = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; + static constexpr auto simt_only_flag + = std::is_same::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE; + int const max_split_k = 1; + int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM; + int const enable_hopper = CutlassGemmConfig::HOPPER; + static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE; + auto config_type_param = static_cast( + weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); + + if (!kernels::cutlass_kernels::isValidHopperMOESpecialisation()) + { + return {}; + } + + std::vector hopper_configs + = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); + return hopper_configs; +} + +template +bool MoeGemmRunner::isHopperSpecialised( + cutlass_extensions::CutlassGemmConfig gemm_config) const +{ + bool config_is_sm90 = gemm_config.is_sm90; + return supportsHopperSpecialisation() && config_is_sm90; +} + +template +bool MoeGemmRunner::supportsHopperSpecialisation() const +{ + return sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation(); +} + +template +int MoeGemmRunner::getSM() const +{ + return this->sm_; +} + +// currently support sm80 bf16/fp16 gate activation, only set predication tensor for m direction +template +bool MoeGemmRunner::supportsFusedGatedActivation( + bool is_gated_activation, int gemm_n, int gemm_k) const +{ + constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true; + return is_gated_activation && std::is_same_v && !std::is_same_v && !use_fp8 + && (this->getSM() >= 80) && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION; +} + +template +bool MoeGemmRunner::isFusedGatedActivation( + cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const +{ + return supportsFusedGatedActivation(is_gated_activation, gemm_n, gemm_k) && !gemm_config.is_sm90; +} + +template +MoeGemmRunner::MoeGemmRunner() +{ + int device{-1}; + tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); + sm_ = tensorrt_llm::common::getSMVersion(); + tensorrt_llm::common::check_cuda_error( + cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); +} + +template +template +void MoeGemmRunner::dispatchToArch(T const* A, + WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, + void* C_void, int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, + bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, 
int* occupancy) +{ + static_assert(std::is_same_v, + "Separate Scale/Bias type is not supported. This is assumed to be the gemm output type"); + + // For now we always cast this to output type. + // In the future this will vary based on what fusions are applied for FP8 + auto* C = reinterpret_cast(C_void); + + TLLM_CHECK_WITH_INFO( + sm_ >= 89 || !hopper_input.isValid(), "Hopper input information is set for non specialised implementation"); + TLLM_CHECK_WITH_INFO( + sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture"); + + if (sm_ >= 75 && sm_ < 80) + { + dispatchMoeGemmToCutlass(A, B, weight_scales, + biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, + gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); + } + else if (sm_ >= 80 && sm_ < 90) + { + if constexpr (use_fp8) + { +#if defined(ENABLE_FP8) + static_assert(!std::is_same_v && !std::is_same_v, + "FP8 GEMM Output not supported"); +#endif + + TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); + dispatchMoeGemmToCutlass(A, B, + weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, + occupancy); + } + else + { + dispatchMoeGemmToCutlass(A, B, + weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, + occupancy); + } + } + else if (sm_ >= 90) + { + if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation()) + { + + // We allow both SM90 and SM80 configurations to coexist because for some cases with small numbers of tokens + // SM80 is faster. 
We check here to see which is selected + if (gemm_config.is_sm90) + { + TLLM_CHECK_WITH_INFO(biases != nullptr || hopper_input.ptr_c == nullptr, + "Input biases and hopper input disagree if bias is enabled"); + TLLM_CHECK_WITH_INFO(hopper_input.isValid(), "Calling SM90 configuration with invalid hopper config"); + + // Select the appropriate fusion function + auto select_function = [&]() + { + switch (hopper_input.fusion) + { + case HopperGroupedGemmInput::EpilogueFusion::FINALIZE: + return &dispatchMoeGemmSelectTileShapeSM90; + case HopperGroupedGemmInput::EpilogueFusion::NONE: + return &dispatchMoeGemmSelectTileShapeSM90; + case HopperGroupedGemmInput::EpilogueFusion::ACTIVATION: + case HopperGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION: + default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_input.fusion); + }; + }; + auto selected_func = select_function(); + selected_func( + hopper_input, num_experts, gemm_config, multi_processor_count_, stream, occupancy, nullptr); + return; + } + + // Fallthrough to SM80 impl below + } + + // Do Ampere case instead + if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation()) + { + TLLM_CHECK_WITH_INFO(!hopper_input.isValid(), + "Non-specialised Hopper implementation is being rerouted to fallback implementation so input " + "information is not required"); + TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90, + "GEMM config is for SM90 configuration, but this configuration is not valid for Hopper"); + dispatchMoeGemmToCutlass(A, B, + weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, + num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, + occupancy); + } + else + { + TLLM_THROW("Configuration expects SM80 but configuration is not supported by SM80 kernels"); + } + } + else + { + TLLM_THROW("Arch unsupported for MoE GEMM"); + } +} + +template +size_t MoeGemmRunner::getMaxWorkspaceSize(int num_experts) const +{ + if (num_experts != num_experts_) + { + TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_); + num_experts_ = num_experts; + gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts); + } + return gemm_workspace_size_; +} + +template +size_t MoeGemmRunner::calcMaxWorkspaceSize(int num_experts) const +{ + if (!supportsHopperSpecialisation()) + { + return 0; + } + if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation()) + { + auto configs = getHopperConfigs(sm_); + size_t max_size = 0; + bool has_config = false; + for (auto conf : configs) + { +#define CALC_SIZE_FUSION(FUSION) \ + do \ + { \ + try \ + { \ + size_t size = calcMaxWorkspaceSizeSM90( \ + num_experts, conf, multi_processor_count_); \ + max_size = std::max(max_size, size); \ + has_config = true; \ + } \ + catch (tensorrt_llm::common::TllmException const& e) \ + { \ + TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size"); \ + } \ + } while (0) + + CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::NONE); + CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::FINALIZE); + } + TLLM_CHECK_WITH_INFO(has_config, "Could not find valid config when calculating workspace size"); + return max_size; + } + else + { + TLLM_THROW("Attempting to calculate Hopper GEMM workspace size with unsupported weight combination"); + return 0; + } +} + +template +template +void MoeGemmRunner::runGemm(T const* A, WeightType const* B, + ScaleBiasType const* weight_scales, ScaleBiasType
const* biases, bool bias_is_broadcast, void* C, + int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array, + cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf) +{ + dispatchToArch(A, B, weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, + hopper_input, total_rows, gemm_n, gemm_k, num_experts, chosen_conf, use_fused_moe, alpha_scale_ptr_array, + stream, nullptr); +} + +template +void MoeGemmRunner::moeGemmBiasAct(T const* A, WeightType const* B, + ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C, + int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, + int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, bool use_fused_moe, + float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf) +{ + switch (activation_type) + { + case ActivationType::Relu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::Gelu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::Silu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::Identity: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::Swiglu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::Geglu: + runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, + total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, + alpha_scale_ptr_array, stream, chosen_conf); + break; + case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break; + default: TLLM_THROW("Invalid activation type."); break; + } +} + +template +void MoeGemmRunner::moeGemm(T const* A, WeightType const* B, + ScaleBiasType const* weight_scales, void* C, int64_t const* total_tokens_including_expert, + HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, + bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, + cutlass_extensions::CutlassGemmConfig chosen_conf) +{ + runGemm(A, B, weight_scales, nullptr, true, C, total_tokens_including_expert, + hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream, + chosen_conf); +} + +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h 
b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h new file mode 100644 index 00000000000..3efb42f41ef --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h @@ -0,0 +1,222 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Ignore CUTLASS warnings about type punning +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ + +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/tensor_ref.h" + +#include "cutlass_extensions/compute_occupancy.h" +#include "cutlass_extensions/epilogue_helpers.h" +#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" +#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" +#include "cutlass_extensions/gemm/threadblock/default_mma.h" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC__ + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" + +#include +#include +#include +#include + +namespace tensorrt_llm +{ +using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion; + +template +void dispatchMoeGemmSelectBiasSM90(HopperGroupedGemmInput hopper_input, int num_experts, int multi_processor_count, + cudaStream_t stream, int* occupancy, size_t* workspace_size) +{ + static_assert(kernels::cutlass_kernels::isValidHopperMOESpecialisation(), + "Invalid hopper configuration invoked, fallback to Sm80"); + + TLLM_CHECK_WITH_INFO( + workspace_size || hopper_input.isValid(), "Hopper specialisation is missing additional input information"); + + // auto func = hopper_input.ptr_c ? 
+ // kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper + // : + // kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper; + // TODO(dastokes) Re-enable bias when CUTLASS supports it + auto func = kernels::cutlass_kernels::sm90_generic_moe_gemm_kernelLauncher; + func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); +} + +/* + 1x1x1 cluster shape is supported for any tile shape. + + 2x1x1 cluster shape is only supported when the M tile is at least 128. + + 1x2x1 cluster shape is only supported when the N tile is at least 128. + + 2x2x1 cluster shape is only supported when both the M and N tiles are at least 128. + + We make the above restrictions to improve compilation speed in TRT-LLM by pruning kernels + that may not be very useful in practice. + */ +template +constexpr bool are_tile_shapes_supported() +{ + using namespace cute; + [[maybe_unused]] constexpr int cta_m = get<0>(CTAShape{}); + [[maybe_unused]] constexpr int cta_n = get<1>(CTAShape{}); + constexpr int cga_m = get<0>(ClusterShape{}); + constexpr int cga_n = get<1>(ClusterShape{}); + + if constexpr (cga_m == _1{} && cga_n == _1{}) + { + return true; + } + else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{}) + { + return true; + } + else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{}) + { + return true; + } + else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{}) + { + return true; + } + else + { + return false; + } +} + +template +void dispatchMoeGemmSelectClusterShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy, + size_t* workspace_size) +{ + using namespace cute; + switch (gemm_config.cluster_shape) + { +#define SHAPE_CASE(M, N, K) \ + case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: \ + { \ + using ClusterShape = Shape<_##M, _##N, _##K>; \ + if constexpr (are_tile_shapes_supported()) \ + { \ + dispatchMoeGemmSelectBiasSM90( \ + hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); \ + break; \ + } \ + else \ + { \ + TLLM_THROW("Unsupported tile and cluster shape combination"); \ + } \ + } + + SHAPE_CASE(1, 1, 1) + SHAPE_CASE(1, 2, 1) + + SHAPE_CASE(2, 1, 1) + SHAPE_CASE(2, 2, 1) + +#undef SHAPE_CASE + default: TLLM_THROW("Unsupported config for MoE gemm."); + } +} + +template +void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts, + cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy, + size_t* workspace_size) +{ + using namespace cute; + + switch (gemm_config.tile_config_sm90) + { +#define SHAPE_CASE(M, N, K) \ + case cutlass_extensions::CutlassTileConfigSM90::CtaShape##M##x##N##x##K##B: \ + { \ + constexpr int KtileBytes = K / sizeof(T); \ + using KTileDim = Int; \ + using TileShape = Shape<_##M, _##N, KTileDim>; \ + dispatchMoeGemmSelectClusterShapeSM90( \ + hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size); \ + break; \ + } + + SHAPE_CASE(128, 16, 128) + SHAPE_CASE(128, 32, 128) + SHAPE_CASE(128, 64, 128) + SHAPE_CASE(128, 128, 128) + SHAPE_CASE(128, 256, 128) + SHAPE_CASE(256, 128, 128) + +#undef SHAPE_CASE + case cutlass_extensions::CutlassTileConfigSM90::Undefined: TLLM_THROW("GEMM config undefined."); break; + case
cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic: + TLLM_THROW("GEMM config should have already been set by heuristic."); + break; + default: TLLM_THROW("Unsupported config for MoE gemm."); break; + } +} + +template +size_t calcMaxWorkspaceSizeSM90( + int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count) +{ + size_t count; + // Most of the values are ignored for WS size calculation. We reuse the function to reduce the template bloat + dispatchMoeGemmSelectTileShapeSM90( + HopperGroupedGemmInput{}, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count); + return count; +} + +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h new file mode 100644 index 00000000000..959d0ea088c --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h @@ -0,0 +1,44 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/mma_sm90.h" +#include "cutlass_extensions/epilogue_helpers.h" + +namespace tensorrt_llm::kernels::cutlass_kernels +{ + +// Hopper arch +template +constexpr bool isValidHopperMOESpecialisation() +{ +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + return cutlass::platform::is_same::value + && cutlass::platform::is_same::value; +#else + return false; // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is set when Hopper kernels are enabled +#endif +} + +// Hopper arch +template +constexpr bool isValidAmpereMOESpecialisation() +{ + return true; // Default to true +} + +} // namespace tensorrt_llm::kernels::cutlass_kernels From 9602c2aac76d2655d4d9aa657e60accde1cfb51f Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Fri, 31 Jan 2025 00:39:47 +0800 Subject: [PATCH 12/12] keep the parts needed for moe_kernels (#3218) --- .../tensorrt_llm/common/CMakeLists.txt | 22 - .../3rdparty/tensorrt_llm/common/assert.cpp | 0 .../3rdparty/tensorrt_llm/common/assert.h | 92 ++ .../tensorrt_llm/common/cudaDriverWrapper.cpp | 187 ++++ .../tensorrt_llm/common/cudaDriverWrapper.h | 138 +++ .../tensorrt_llm/common/cudaFp8Utils.h | 239 +++++ .../tensorrt_llm/common/cudaProfilerUtils.cpp | 84 -- .../3rdparty/tensorrt_llm/common/cudaUtils.h | 641 +++++++++++++ .../common/customAllReduceUtils.h | 36 - .../3rdparty/tensorrt_llm/common/envUtils.cpp | 214 ----- .../3rdparty/tensorrt_llm/common/envUtils.h | 60 -- .../3rdparty/tensorrt_llm/common/logger.h | 190 ++++ .../3rdparty/tensorrt_llm/common/mathUtils.h | 37 - .../tensorrt_llm/common/memoryUtils.cu | 906 ------------------ .../tensorrt_llm/common/memoryUtils.h | 292 ------ .../3rdparty/tensorrt_llm/common/mpiUtils.cpp | 588 ------------ .../3rdparty/tensorrt_llm/common/nvtxUtils.h 
| 46 - .../3rdparty/tensorrt_llm/common/opUtils.cpp | 323 ------- .../3rdparty/tensorrt_llm/common/opUtils.h | 215 ----- .../tensorrt_llm/common/quantization.h | 358 +++++++ .../3rdparty/tensorrt_llm/common/stlUtils.h | 123 --- .../tensorrt_llm/common/stringUtils.h | 113 +++ .../tensorrt_llm/common/timestampUtils.cpp | 42 - .../{timestampUtils.h => tllmException.h} | 27 +- 24 files changed, 1983 insertions(+), 2990 deletions(-) delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt mode change 100755 => 100644 sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/assert.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/logger.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h delete mode 100644 sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp rename sgl-kernel/3rdparty/tensorrt_llm/common/{timestampUtils.h => tllmException.h} (50%) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt b/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt deleted file mode 100644 index e479b298db4..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt +++ /dev/null @@ -1,22 +0,0 @@ -# -# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & -# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not -# use this file except in compliance with the License. You may obtain a copy of -# the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations under -# the License. 
-# -file(GLOB SRCS *.cpp) -file(GLOB CU_SRCS *.cu) - -add_library(common_src OBJECT ${SRCS} ${CU_SRCS}) -set_property(TARGET common_src PROPERTY POSITION_INDEPENDENT_CODE ON) -set_property(TARGET common_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp old mode 100755 new mode 100644 diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.h b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.h new file mode 100644 index 00000000000..7f51dbf1b41 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.h @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/stringUtils.h" +#include "tensorrt_llm/common/tllmException.h" + +#include + +namespace tensorrt_llm::common +{ +[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "") +{ + throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str())); +} + +} // namespace tensorrt_llm::common + +class DebugConfig +{ +public: + static bool isCheckDebugEnabled(); +}; + +#if defined(_WIN32) +#define TLLM_LIKELY(x) (__assume((x) == 1), (x)) +#define TLLM_UNLIKELY(x) (__assume((x) == 0), (x)) +#else +#define TLLM_LIKELY(x) __builtin_expect((x), 1) +#define TLLM_UNLIKELY(x) __builtin_expect((x), 0) +#endif + +#define TLLM_CHECK(val) \ + do \ + { \ + TLLM_LIKELY(static_cast(val)) ? ((void) 0) \ + : tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \ + } while (0) + +#define TLLM_CHECK_WITH_INFO(val, info, ...) \ + do \ + { \ + TLLM_LIKELY(static_cast(val)) \ + ? ((void) 0) \ + : tensorrt_llm::common::throwRuntimeError( \ + __FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \ + } while (0) + +#define TLLM_CHECK_DEBUG(val) \ + do \ + { \ + if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \ + { \ + TLLM_LIKELY(static_cast(val)) ? ((void) 0) \ + : tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \ + } \ + } while (0) + +#define TLLM_CHECK_DEBUG_WITH_INFO(val, info, ...) \ + do \ + { \ + if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \ + { \ + TLLM_LIKELY(static_cast(val)) \ + ? ((void) 0) \ + : tensorrt_llm::common::throwRuntimeError( \ + __FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \ + } \ + } while (0) + +#define TLLM_THROW(...) 
\ + do \ + { \ + throw NEW_TLLM_EXCEPTION(__VA_ARGS__); \ + } while (0) + +#define TLLM_WRAP(ex) \ + NEW_TLLM_EXCEPTION("%s: %s", tensorrt_llm::common::TllmException::demangle(typeid(ex).name()).c_str(), ex.what()) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp new file mode 100644 index 00000000000..7eca46a1cab --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#define CUDA_LIB_NAME "cuda" + +#if defined(_WIN32) +#include +#define dllOpen(name) LoadLibrary("nv" name ".dll") +#define dllClose(handle) FreeLibrary(static_cast(handle)) +#define dllGetSym(handle, name) static_cast(GetProcAddress(static_cast(handle), name)) +#else // For non-Windows platforms +#include +#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY) +#define dllClose(handle) dlclose(handle) +#define dllGetSym(handle, name) dlsym(handle, name) +#endif // defined(_WIN32) + +#include "cudaDriverWrapper.h" +#include "tensorrt_llm/common/assert.h" +#include +#include + +namespace tensorrt_llm::common +{ + +std::shared_ptr CUDADriverWrapper::getInstance() +{ + static std::mutex mutex; + static std::weak_ptr instance; + std::shared_ptr result = instance.lock(); + if (result) + { + return result; + } + + std::lock_guard lock(mutex); + result = instance.lock(); + if (!result) + { + result = std::shared_ptr(new CUDADriverWrapper()); + instance = result; + } + return result; +} + +CUDADriverWrapper::CUDADriverWrapper() + : handle(dllOpen(CUDA_LIB_NAME)) +{ + + TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly."); + + auto load_sym = [](void* handle, char const* name) + { + void* ret = dllGetSym(handle, name); + return ret; + }; + + *reinterpret_cast(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName"); + *reinterpret_cast(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage"); + *reinterpret_cast(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute"); + *reinterpret_cast(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete"); + *reinterpret_cast(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload"); + *reinterpret_cast(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy"); + *reinterpret_cast(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData"); + *reinterpret_cast(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2"); + *reinterpret_cast(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction"); + *reinterpret_cast(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2"); + *reinterpret_cast(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2"); + *reinterpret_cast(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2"); + *reinterpret_cast(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel"); + *reinterpret_cast(&_cuLaunchKernel) = load_sym(handle, 
"cuLaunchKernel"); + *reinterpret_cast(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled"); + *reinterpret_cast(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2"); +} + +CUDADriverWrapper::~CUDADriverWrapper() +{ + dllClose(handle); +} + +CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const +{ + return (*_cuGetErrorName)(error, pStr); +} + +CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const +{ + return (*_cuGetErrorMessage)(error, pStr); +} + +CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const +{ + return (*_cuFuncSetAttribute)(hfunc, attrib, value); +} + +CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const +{ + return (*_cuLinkComplete)(state, cubinOut, sizeOut); +} + +CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const +{ + return (*_cuModuleUnload)(hmod); +} + +CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const +{ + return (*_cuLinkDestroy)(state); +} + +CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const +{ + return (*_cuModuleLoadData)(module, image); +} + +CUresult CUDADriverWrapper::cuLinkCreate( + unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const +{ + return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut); +} + +CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const +{ + return (*_cuModuleGetFunction)(hfunc, hmod, name); +} + +CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const +{ + return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name); +} + +CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, + unsigned int numOptions, CUjit_option* options, void** optionValues) const +{ + return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues); +} + +CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, + char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const +{ + return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues); +} + +CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const +{ + return (*_cuLaunchCooperativeKernel)( + f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams); +} + +CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const +{ + return (*_cuLaunchKernel)( + f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra); +} + +CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, + cuuint32_t const* boxDim, cuuint32_t const* 
elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const +{ + return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides, + boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill); +} + +CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const +{ + return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount); +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h new file mode 100644 index 00000000000..c4d470a85f0 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CUDA_DRIVER_WRAPPER_H +#define CUDA_DRIVER_WRAPPER_H + +#include "tensorrt_llm/common/assert.h" +#include +#include +#include +#include + +namespace tensorrt_llm::common +{ + +class CUDADriverWrapper +{ +public: + static std::shared_ptr getInstance(); + + ~CUDADriverWrapper(); + CUDADriverWrapper(CUDADriverWrapper const&) = delete; + CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete; + CUDADriverWrapper(CUDADriverWrapper&&) = delete; + CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete; + + CUresult cuGetErrorName(CUresult error, char const** pStr) const; + + CUresult cuGetErrorMessage(CUresult error, char const** pStr) const; + + CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const; + + CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const; + + CUresult cuModuleUnload(CUmodule hmod) const; + + CUresult cuLinkDestroy(CUlinkState state) const; + + CUresult cuModuleLoadData(CUmodule* module, void const* image) const; + + CUresult cuLinkCreate( + unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const; + + CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const; + + CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const; + + CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions, + CUjit_option* options, void** optionValues) const; + + CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name, + unsigned int numOptions, CUjit_option* options, void** optionValues) const; + + CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const; + + CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, 
unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, + CUstream hStream, void** kernelParams, void** extra) const; + + CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, + void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim, + cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const; + + CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const; + +private: + void* handle; + CUDADriverWrapper(); + + CUresult (*_cuGetErrorName)(CUresult, char const**); + CUresult (*_cuGetErrorMessage)(CUresult, char const**); + CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int); + CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*); + CUresult (*_cuModuleUnload)(CUmodule); + CUresult (*_cuLinkDestroy)(CUlinkState); + CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*); + CUresult (*_cuModuleLoadData)(CUmodule*, void const*); + CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*); + CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*); + CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**); + CUresult (*_cuLinkAddData)( + CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**); + CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, CUstream, void**); + CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, + CUstream hStream, void** kernelParams, void** extra); + CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, + cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); + CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount); +}; + +template +void checkDriver( + T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line) +{ + if (result) + { + char const* errorName = nullptr; + char const* errorMsg = nullptr; + wrap.cuGetErrorName(result, &errorName); + wrap.cuGetErrorMessage(result, &errorMsg); + throw TllmException( + file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg)); + } +} + +} // namespace tensorrt_llm::common + +/* + * Macros compliant with TensorRT coding conventions + */ +#define TLLM_CU_CHECK(stat) \ + do \ + { \ + tensorrt_llm::common::checkDriver( \ + (stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \ + } while (0) + +#endif // CUDA_DRIVER_WRAPPER_H diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h new file mode 100644 index 00000000000..aa93b55a579 --- /dev/null +++ 
b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h @@ -0,0 +1,239 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef ENABLE_FP8 +#include +#include +#include + +#define FP8_MHA +#define FUSE_GEMM_ACT +#define FP8_GEMM_OUTPUT_QUANT_DISABLE + +#ifdef FUSE_GEMM_ACT +#define USE_QGMMA +#endif + +namespace tensorrt_llm +{ +namespace common +{ + +constexpr float FP8_E4M3_MAX = 448.0f; + +enum QuantizeMode +{ + PER_CHANNEL, + PER_TENSOR, + PER_CHANNEL_WEIGHT_PER_TENSOR_ACT, + PER_TOKEN, +}; + +// Packed Data Type +typedef struct __CUDA_ALIGN__(32) +{ + float array[8]; +} float8; + +typedef struct __CUDA_ALIGN__(16) +{ + half array[8]; +} half8; + +typedef struct __CUDA_ALIGN__(8) +{ + half2 array[2]; +} half2_2; + +typedef struct __CUDA_ALIGN__(8) +{ + half array[4]; +} half_4; + +#ifdef ENABLE_BF16 +typedef struct __CUDA_ALIGN__(4) +{ + __nv_bfloat16 array[2]; +} __nv_bfloat16_2; + +typedef struct __CUDA_ALIGN__(8) +{ + __nv_bfloat162 x, y; +} __nv_bfloat162_2_xy; + +typedef struct __CUDA_ALIGN__(8) +{ + __nv_bfloat16 array[4]; +} __nv_bfloat164; + +typedef struct __CUDA_ALIGN__(8) +{ + __nv_bfloat162 array[2]; +} __nv_bfloat162_2; + +typedef struct __CUDA_ALIGN__(16) +{ + __nv_bfloat16 array[8]; +} __nv_bfloat168; + +typedef struct __CUDA_ALIGN__(16) +{ + __nv_bfloat162 array[4]; +} __nv_bfloat162_4; + +typedef struct __CUDA_ALIGN__(32) +{ + __nv_bfloat16 array[16]; +} __nv_bfloat1616; +#endif + +#ifdef ENABLE_FP8 +typedef struct __CUDA_ALIGN__(2) +{ + __nv_fp8_e4m3 array[2]; +} __nv_fp8_2_e4m3; + +typedef struct __CUDA_ALIGN__(4) +{ + __nv_fp8_e4m3 array[4]; +} __nv_fp8_4_e4m3; + +typedef struct __CUDA_ALIGN__(4) +{ + __nv_fp8x2_e4m3 array[2]; +} __nv_fp8x2_x2_e4m3; + +typedef struct __CUDA_ALIGN__(8) +{ + __nv_fp8_e4m3 array[8]; +} __nv_fp8_8_e4m3; + +typedef struct __CUDA_ALIGN__(8) +{ + __nv_fp8x2_e4m3 array[4]; +} __nv_fp8x2_x4_e4m3; + +typedef struct __CUDA_ALIGN__(16) +{ + __nv_fp8_e4m3 array[16]; +} __nv_fp8x16_e4m3; +#endif + +// only BF16 and FP8 +template +struct PackType +{ + using type = float; +}; + +#ifdef ENABLE_BF16 +template <> +struct PackType<__nv_bfloat16, 2> +{ + using type = __nv_bfloat16_2; +}; + +template <> +struct PackType<__nv_bfloat16, 4> +{ + using type = __nv_bfloat164; +}; + +template <> +struct PackType<__nv_bfloat16, 8> +{ + using type = __nv_bfloat168; +}; +#endif + +#ifdef ENABLE_FP8 +template <> +struct PackType<__nv_fp8_e4m3, 2> +{ + using type = __nv_fp8_2_e4m3; +}; + +template <> +struct PackType<__nv_fp8_e4m3, 4> +{ + using type = __nv_fp8_4_e4m3; +}; + +template <> +struct PackType<__nv_fp8_e4m3, 8> +{ + using type = __nv_fp8_8_e4m3; +}; +#endif + +__inline__ __device__ void fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, __nv_fp8x4_e4m3 const* in) +{ + const char4 tmp_val = reinterpret_cast(in)[0]; + *out1 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], + (float) 
reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); + *out2 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]); +} + +__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(__nv_fp8x2_e4m3 const* in) +{ + const char2 tmp_val = reinterpret_cast(in)[0]; + __nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); + return out; +} + +__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, __nv_fp8x4_e4m3 const* in) +{ + const char4 tmp_val = reinterpret_cast(in)[0]; + *out1 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); + *out2 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]); +} + +__inline__ __device__ half2 fp8x2_e4m3_to_half2(__nv_fp8x2_e4m3 const* in) +{ + const char2 tmp_val = reinterpret_cast(in)[0]; + half2 out = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], + (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); + return out; +} + +template +void invokeQuantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream); + +template +void invokeDequantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream); + +template +void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream); + +template +void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t k, const int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream); + +template +void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* weights, const int64_t numel, + const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream); + +} // namespace common +} // namespace tensorrt_llm +#endif // ENABLE_FP8 diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp deleted file mode 100644 index 5576fe782fa..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/common/cudaProfilerUtils.h" -#include "tensorrt_llm/common/logger.h" -#include "tensorrt_llm/common/stringUtils.h" -#include -#include - -namespace -{ - -std::tuple, std::unordered_set> populateIterationIndexesImpl( - std::string const& envVarName) -{ - auto envVarVal = std::getenv(envVarName.c_str()); - auto envVarValStr = std::string{envVarVal != nullptr ? 
envVarVal : ""}; - auto values = tensorrt_llm::common::str2set(envVarValStr, ','); - std::unordered_set startSet; - std::unordered_set endSet; - for (std::string const& value : values) - { - size_t dashIdx = value.find("-"); - if (dashIdx != std::string::npos) - { - int32_t start = std::stoi(value.substr(0, dashIdx)); - startSet.insert(start); - int32_t end = std::stoi(value.substr(dashIdx + 1)); - endSet.insert(end); - } - else - { - int32_t start_end = std::stoi(value); - startSet.insert(start_end); - endSet.insert(start_end); - } - } - - return std::make_pair(startSet, endSet); -} - -} // namespace - -namespace tensorrt_llm::common -{ - -std::pair, std::unordered_set> populateIterationIndexes( - std::string const& envVarName, std::optional const& legacyEnvVarName) -{ - auto [profileIterIdxs, stopIterIdxs] = populateIterationIndexesImpl(envVarName); - - // If empty, try to use legacy env var name - if (legacyEnvVarName && profileIterIdxs.empty() && stopIterIdxs.empty()) - { - std::tie(profileIterIdxs, stopIterIdxs) = populateIterationIndexesImpl(legacyEnvVarName.value()); - - if (!profileIterIdxs.empty() || !stopIterIdxs.empty()) - { - TLLM_LOG_WARNING( - "Using deprecated environment variable %s to specify cudaProfiler start and stop iterations. " - "Please " - "use %s " - "instead.", - legacyEnvVarName.value().c_str(), envVarName.c_str()); - } - } - - return std::make_pair(profileIterIdxs, stopIterIdxs); -} - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h new file mode 100644 index 00000000000..13ee3367e97 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h @@ -0,0 +1,641 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "tensorrt_llm/common/cudaBf16Wrapper.h" +#include "tensorrt_llm/common/cudaDriverWrapper.h" +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/tllmException.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifndef _WIN32 // Linux +#include +#endif // not WIN32 +#include +#ifdef _WIN32 // Windows +#include +#undef ERROR // A Windows header file defines ERROR as 0, but it's used in our logger.h enum. Logging breaks without + // this undef. 
+#endif // WIN32 + +namespace tensorrt_llm::common +{ + +// workspace for cublas gemm : 32MB +#define CUBLAS_WORKSPACE_SIZE 33554432 + +typedef struct __align__(4) +{ + half x, y, z, w; +} + +half4; + +/* **************************** type definition ***************************** */ + +enum CublasDataType +{ + FLOAT_DATATYPE = 0, + HALF_DATATYPE = 1, + BFLOAT16_DATATYPE = 2, + INT8_DATATYPE = 3, + FP8_DATATYPE = 4 +}; + +enum TRTLLMCudaDataType +{ + FP32 = 0, + FP16 = 1, + BF16 = 2, + INT8 = 3, + FP8 = 4 +}; + +enum class OperationType +{ + FP32, + FP16, + BF16, + INT8, + FP8 +}; + +/* **************************** debug tools ********************************* */ +static char const* _cudaGetErrorEnum(cudaError_t error) +{ + return cudaGetErrorString(error); +} + +static char const* _cudaGetErrorEnum(cublasStatus_t error) +{ + switch (error) + { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + + case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; + + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + + case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; + + case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + } + return ""; +} + +template +void check(T result, char const* const func, char const* const file, int const line) +{ + if (result) + { + throw TllmException( + file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA runtime error in %s: %s", func, _cudaGetErrorEnum(result))); + } +} + +template +void checkEx(T result, std::initializer_list const& validReturns, char const* const func, char const* const file, + int const line) +{ + if (std::all_of(std::begin(validReturns), std::end(validReturns), [&result](T const& t) { return t != result; })) + { + throw TllmException( + file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA runtime error in %s: %s", func, _cudaGetErrorEnum(result))); + } +} + +#define check_cuda_error(val) check((val), #val, __FILE__, __LINE__) +#define check_cuda_error_2(val, file, line) check((val), #val, file, line) + +inline std::optional isCudaLaunchBlocking() +{ + static bool firstCall = true; + static std::optional result = std::nullopt; + + if (firstCall) + { + char const* env = std::getenv("CUDA_LAUNCH_BLOCKING"); + if (env != nullptr && std::string(env) == "1") + { + result = true; + } + else if (env != nullptr && std::string(env) == "0") + { + result = false; + } + firstCall = false; + } + + return result; +} + +inline bool doCheckError() +{ + auto const cudaLaunchBlocking = isCudaLaunchBlocking(); +#ifndef NDEBUG + bool const checkError = cudaLaunchBlocking.value_or(true); +#else + bool const checkError = cudaLaunchBlocking.value_or(false); +#endif + + return checkError; +} + +inline void syncAndCheck(char const* const file, int const line) +{ + if (doCheckError()) + { + cudaDeviceSynchronize(); + check(cudaGetLastError(), "cudaGetLastError", file, line); + } +} + +#define sync_check_cuda_error() tensorrt_llm::common::syncAndCheck(__FILE__, __LINE__) + +#define PRINT_FUNC_NAME_() \ + do \ + { \ + std::cout << "[TensorRT-LLM][CALL] " << __FUNCTION__ << " " << std::endl; \ + } while (0) + +// 
clang-format off +template struct packed_type; +template <> struct packed_type { using type = float; }; // we don't need to pack float by default +template <> struct packed_type { using type = half2; }; + +#ifdef ENABLE_BF16 +template<> +struct packed_type<__nv_bfloat16> { + using type = __nv_bfloat162; +}; +#endif + +#ifdef ENABLE_FP8 +template<> +struct packed_type<__nv_fp8_e4m3> { + using type = __nv_fp8x2_e4m3; +}; +#endif + +template struct num_elems; +template <> struct num_elems { static constexpr int value = 1; }; +template <> struct num_elems { static constexpr int value = 2; }; +template <> struct num_elems { static constexpr int value = 4; }; +template <> struct num_elems { static constexpr int value = 1; }; +template <> struct num_elems { static constexpr int value = 2; }; +#ifdef ENABLE_BF16 +template <> struct num_elems<__nv_bfloat16> { static constexpr int value = 1; }; +template <> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; }; +#endif +#ifdef ENABLE_FP8 +template <> struct num_elems<__nv_fp8_e4m3> { static constexpr int value = 1; }; +template <> struct num_elems<__nv_fp8x2_e4m3> { static constexpr int value = 2; }; +#endif + +template struct packed_as; +template struct packed_as { using type = T; }; +template<> struct packed_as { using type = half2; }; +template<> struct packed_as { using type = float2; }; +template<> struct packed_as { using type = int16_t; }; +template<> struct packed_as { using type = int2; }; +template<> struct packed_as { using type = half; }; +template<> struct packed_as { using type = float; }; +#ifdef ENABLE_BF16 +template<> struct packed_as<__nv_bfloat16, 2> { using type = __nv_bfloat162; }; +template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16; }; +#endif +#ifdef ENABLE_FP8 +template<> struct packed_as<__nv_fp8_e4m3, 2> { using type = __nv_fp8x2_e4m3; }; +template<> struct packed_as<__nv_fp8x2_e4m3, 1> { using type = __nv_fp8_e4m3; }; +template<> struct packed_as<__nv_fp8_e5m2, 2> { using type = __nv_fp8x2_e5m2; }; +template<> struct packed_as<__nv_fp8x2_e5m2, 1> { using type = __nv_fp8_e5m2; }; +#endif + +inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); } +inline __device__ float2 operator+(float2 a, float2 b) { return make_float2(a.x + b.x, a.y + b.y); } +inline __device__ float2 operator-(float2 a, float2 b) { return make_float2(a.x - b.x, a.y - b.y); } + +inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); } +inline __device__ float2 operator+(float2 a, float b) { return make_float2(a.x + b, a.y + b); } +inline __device__ float2 operator-(float2 a, float b) { return make_float2(a.x - b, a.y - b); } + +// clang-format on + +template +struct CudaDataType +{ +}; + +template <> +struct CudaDataType +{ + static constexpr cudaDataType_t value = cudaDataType::CUDA_R_32F; +}; + +template <> +struct CudaDataType +{ + static constexpr cudaDataType_t value = cudaDataType::CUDA_R_16F; +}; + +#ifdef ENABLE_BF16 +template <> +struct CudaDataType<__nv_bfloat16> +{ + static constexpr cudaDataType_t value = cudaDataType::CUDA_R_16BF; +}; +#endif + +inline int getSMVersion() +{ + int device{-1}; + check_cuda_error(cudaGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); + check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +inline int 
getDevice() +{ + int current_dev_id = 0; + check_cuda_error(cudaGetDevice(¤t_dev_id)); + return current_dev_id; +} + +inline int getDeviceCount() +{ + int count = 0; + check_cuda_error(cudaGetDeviceCount(&count)); + return count; +} + +/// @brief Identifies the memory type of the given pointer. +template +cudaMemoryType getPtrCudaMemoryType(T* ptr) +{ + cudaPointerAttributes attributes{}; + check_cuda_error(cudaPointerGetAttributes(&attributes, ptr)); + return attributes.type; +} + +/// Get the memory info +/// \return The free and total amount of memory in bytes +inline std::tuple getDeviceMemoryInfo(bool const useUvm) +{ + if (useUvm) + { + size_t freeSysMem = 0; + size_t totalSysMem = 0; +#ifndef _WIN32 // Linux + struct sysinfo info + { + }; + + sysinfo(&info); + totalSysMem = info.totalram * info.mem_unit; + freeSysMem = info.freeram * info.mem_unit; +#else // Windows + MEMORYSTATUSEX memInfo; + memInfo.dwLength = sizeof(memInfo); + GlobalMemoryStatusEx(&memInfo); + totalSysMem = memInfo.ullTotalPhys; + freeSysMem = memInfo.ullAvailPhys; +#endif // WIN32 + + TLLM_LOG_INFO("Using UVM based system memory for KV cache, total memory %0.2f GB, available memory %0.2f GB", + ((double) totalSysMem / 1e9), ((double) freeSysMem / 1e9)); + return {freeSysMem, totalSysMem}; + } + + size_t free = 0; + size_t total = 0; + check_cuda_error(cudaMemGetInfo(&free, &total)); + TLLM_LOG_DEBUG("Using GPU memory for KV cache, total memory %0.2f GB, available memory %0.2f GB", + ((double) total / 1e9), ((double) free / 1e9)); + return {free, total}; +} + +/// @brief Gets the memory allocation granularity for the current device. +/// +/// @return size_t The size of the smallest difference in memory size supported by the current device. +inline size_t getAllocationGranularity() +{ + auto const currentDevice = getDevice(); + ::CUmemAllocationProp prop = {}; + + prop.type = ::CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = ::CU_MEM_LOCATION_TYPE_DEVICE; + prop.location.id = currentDevice; + prop.requestedHandleTypes = ::CU_MEM_HANDLE_TYPE_NONE; + + // Get the minimum granularity supported for allocation with cuMemCreate() + size_t granularity = 0; + TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); + return granularity; +} + +inline int getMultiProcessorCount() +{ + int device_id = 0; + int multi_processor_count = 0; + check_cuda_error(cudaGetDevice(&device_id)); + check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count, cudaDevAttrMultiProcessorCount, device_id)); + return multi_processor_count; +} + +inline int getMaxSharedMemoryPerBlockOptin() +{ + int device_id = 0; + int max_shared_memory_per_block = 0; + check_cuda_error(cudaGetDevice(&device_id)); + check_cuda_error( + cudaDeviceGetAttribute(&max_shared_memory_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id)); + return max_shared_memory_per_block; +} + +template +inline size_t divUp(const T1& a, const T2& n) +{ + auto const tmp_a = static_cast(a); + auto const tmp_n = static_cast(n); + return (tmp_a + tmp_n - 1) / tmp_n; +} + +inline int roundUp(int a, int n) +{ + return divUp(a, n) * n; +} + +template ::value>, + typename = std::enable_if_t::value>> +auto constexpr ceilDiv(T numerator, U denominator) +{ + return (numerator + denominator - 1) / denominator; +} + +template +void printAbsMean(T const* buf, uint64_t size, cudaStream_t stream, std::string name = "") +{ + if (buf == nullptr) + { + TLLM_LOG_WARNING("%s is an nullptr, skip!", name.c_str()); + return; + } + 
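
The helpers above centralize CUDA error handling: `check_cuda_error` wraps a CUDA runtime call and throws a `TllmException` on failure, while `sync_check_cuda_error` synchronizes and re-checks only when `CUDA_LAUNCH_BLOCKING=1` is set (or in debug builds). A hedged usage sketch, assuming the `cudaUtils.h` header added in this patch is on the include path:

```cpp
// Usage sketch of check_cuda_error / sync_check_cuda_error and the device queries above.
#include "tensorrt_llm/common/cudaUtils.h"
#include <cuda_runtime.h>

int main()
{
    using namespace tensorrt_llm::common;

    TLLM_LOG_INFO("running on SM%d with %d SMs", getSMVersion(), getMultiProcessorCount());

    float* buf = nullptr;
    check_cuda_error(cudaMalloc(&buf, 1024 * sizeof(float)));   // throws TllmException on failure
    check_cuda_error(cudaMemset(buf, 0, 1024 * sizeof(float)));
    sync_check_cuda_error();                                    // no-op unless checking is enabled

    check_cuda_error(cudaFree(buf));
    return 0;
}
```
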
cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + T* h_tmp = new T[size]; + cudaMemcpyAsync(h_tmp, buf, sizeof(T) * size, cudaMemcpyDeviceToHost, stream); + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); + double sum = 0.0f; + uint64_t zero_count = 0; + float max_val = -1e10; + bool find_inf = false; + for (uint64_t i = 0; i < size; i++) + { + if (std::isinf((float) (h_tmp[i]))) + { + find_inf = true; + continue; + } + sum += abs((double) h_tmp[i]); + if ((float) h_tmp[i] == 0.0f) + { + zero_count++; + } + max_val = max_val > abs(float(h_tmp[i])) ? max_val : abs(float(h_tmp[i])); + } + TLLM_LOG_INFO("%20s size: %u, abs mean: %f, abs sum: %f, abs max: %f, find inf: %s", name.c_str(), size, sum / size, + sum, max_val, find_inf ? "true" : "false"); + delete[] h_tmp; + cudaDeviceSynchronize(); + check_cuda_error(cudaGetLastError()); +} + +template +void printToStream(T const* result, int const size, FILE* strm) +{ + bool const split_rows = (strm == stdout); + if (result == nullptr) + { + TLLM_LOG_WARNING("It is an nullptr, skip! \n"); + return; + } + T* tmp = reinterpret_cast(malloc(sizeof(T) * size)); + check_cuda_error(cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost)); + for (int i = 0; i < size; ++i) + { + fprintf(strm, "%f, ", static_cast(tmp[i])); + if (split_rows && ((i + 1) % 10) == 0) + fprintf(strm, "\n"); + } + if (!split_rows || (size % 10) != 0) + { + fprintf(strm, "\n"); + } + free(tmp); +} + +template +void printToScreen(T const* result, int const size) +{ + printToStream(result, size, stdout); +} + +template +void print2dToStream(T const* result, int const r, int const c, int const stride, FILE* strm) +{ + if (result == nullptr) + { + TLLM_LOG_WARNING("It is an nullptr, skip! \n"); + return; + } + for (int ri = 0; ri < r; ++ri) + { + T const* ptr = result + ri * stride; + printToStream(ptr, c, strm); + } + fprintf(strm, "\n"); +} + +template +void print2dToScreen(T const* result, int const r, int const c, int const stride) +{ + print2dToStream(result, r, c, stride, stdout); +} + +template +void print2dToFile(std::string fname, T const* result, int const r, int const c, int const stride) +{ + FILE* fp = fopen(fname.c_str(), "wt"); + if (fp != nullptr) + { + print2dToStream(result, r, c, stride, fp); + fclose(fp); + } +} + +inline void print_float_(float x) +{ + printf("%7.3f ", x); +} + +inline void print_element_(float x) +{ + print_float_(x); +} + +inline void print_element_(half x) +{ + print_float_((float) x); +} + +#ifdef ENABLE_BF16 +inline void print_element_(__nv_bfloat16 x) +{ + print_float_((float) x); +} +#endif + +#ifdef ENABLE_FP8 +inline void print_element_(__nv_fp8_e4m3 x) +{ + print_float_((float) x); +} +#endif + +inline void print_element_(uint32_t ul) +{ + printf("%7" PRIu32, ul); +} + +inline void print_element_(uint64_t ull) +{ + printf("%7" PRIu64, ull); +} + +inline void print_element_(int32_t il) +{ + printf("%7" PRId32, il); +} + +inline void print_element_(int64_t ill) +{ + printf("%7" PRId64, ill); +} + +template +inline void printMatrix(T const* ptr, int m, int k, int stride, bool is_device_ptr) +{ + T* tmp; + if (is_device_ptr) + { + // k < stride ; stride = col-dimension. 
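
`printAbsMean` and the `print2dTo*` family above are host-side debugging aids: they copy a device buffer back and either report simple statistics or dump it as a matrix. A hedged usage sketch, again assuming the `cudaUtils.h` header from this patch:

```cpp
// Usage sketch of the debug-print helpers defined above.
#include "tensorrt_llm/common/cudaUtils.h"
#include <cuda_runtime.h>
#include <vector>

int main()
{
    using namespace tensorrt_llm::common;

    constexpr int rows = 4, cols = 8;
    std::vector<float> h(rows * cols, 1.5f);

    float* d = nullptr;
    check_cuda_error(cudaMalloc(&d, sizeof(float) * h.size()));
    check_cuda_error(cudaMemcpy(d, h.data(), sizeof(float) * h.size(), cudaMemcpyHostToDevice));

    // Logs abs mean 1.5, abs sum 48, abs max 1.5 for this buffer.
    printAbsMean(d, h.size(), /*stream=*/0, "demo_buffer");

    // Prints the device buffer as a 4x8 matrix with stride == cols.
    print2dToScreen(d, rows, cols, /*stride=*/cols);

    check_cuda_error(cudaFree(d));
    return 0;
}
```
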
+ tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); + check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); + cudaDeviceSynchronize(); + } + else + { + tmp = const_cast(ptr); + } + + for (int ii = -1; ii < m; ++ii) + { + if (ii >= 0) + { + printf("%07d ", ii); + } + else + { + printf(" "); + } + + for (int jj = 0; jj < k; jj += 1) + { + if (ii >= 0) + { + print_element_(tmp[ii * stride + jj]); + } + else + { + printf("%7d ", jj); + } + } + printf("\n"); + } + if (is_device_ptr) + { + free(tmp); + } +} + +template void printMatrix(float const* ptr, int m, int k, int stride, bool is_device_ptr); +template void printMatrix(half const* ptr, int m, int k, int stride, bool is_device_ptr); +#ifdef ENABLE_BF16 +template void printMatrix(__nv_bfloat16 const* ptr, int m, int k, int stride, bool is_device_ptr); +#endif +#ifdef ENABLE_FP8 +template void printMatrix(__nv_fp8_e4m3 const* ptr, int m, int k, int stride, bool is_device_ptr); +#endif +template void printMatrix(uint32_t const* ptr, int m, int k, int stride, bool is_device_ptr); +template void printMatrix(uint64_t const* ptr, int m, int k, int stride, bool is_device_ptr); +template void printMatrix(int const* ptr, int m, int k, int stride, bool is_device_ptr); + +} // namespace tensorrt_llm::common + +/* + * Macros compliant with TensorRT coding conventions + */ +#define TLLM_CUDA_CHECK(stat) \ + do \ + { \ + tensorrt_llm::common::check((stat), #stat, __FILE__, __LINE__); \ + } while (0) + +// We use singleton memory pool and the order of destructors depends on the compiler implementation. We find that the +// cudaFree/cudaFreeHost is called after cudaruntime destruction on Windows. There will be an cudaErrorCudartUnloading +// error. However, it is safe to ignore this error because the cuda runtime is already exited, we are no more worried +// about the memory leaks. +#define TLLM_CUDA_CHECK_FREE_RESOURCE(stat) \ + do \ + { \ + tensorrt_llm::common::checkEx((stat), {cudaSuccess, cudaErrorCudartUnloading}, #stat, __FILE__, __LINE__); \ + } while (0) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h deleted file mode 100644 index d7bf43b4075..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include - -namespace tensorrt_llm::utils::customAllReduceUtils -{ - -constexpr size_t NUM_POINTERS_PER_RANK = 7; - -// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py -inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept -{ - if (worldSize <= 2) - { - return 16 * 1000 * 1000; - } - return 8 * 1000 * 1000; -} - -} // namespace tensorrt_llm::utils::customAllReduceUtils diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp deleted file mode 100644 index 64d3d44acb8..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp +++ /dev/null @@ -1,214 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "envUtils.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/common/logger.h" -#include - -namespace tensorrt_llm::common -{ - -std::optional getIntEnv(char const* name) -{ - char const* const env = std::getenv(name); - if (env == nullptr) - { - return std::nullopt; - } - int32_t const val = std::stoi(env); - if (val <= 0) - { - return std::nullopt; - } - return {val}; -}; - -// Returns true if the env variable exists and is set to "1" -static bool getBoolEnv(char const* name) -{ - char const* env = std::getenv(name); - return env && env[0] == '1' && env[1] == '\0'; -} - -// XQA kernels (optimized kernels for generation phase). -bool forceXQAKernels() -{ - static bool const forceXQA = (getIntEnv("TRTLLM_FORCE_XQA").value_or(0) != 0); - return forceXQA; -} - -std::optional getEnvEnableXQAJIT() -{ - static bool init = false; - static bool exists = false; - static bool enableXQAJIT = false; - if (!init) - { - init = true; - char const* enable_xqa_jit_var = std::getenv("TRTLLM_ENABLE_XQA_JIT"); - if (enable_xqa_jit_var) - { - exists = true; - if (enable_xqa_jit_var[0] == '1' && enable_xqa_jit_var[1] == '\0') - { - enableXQAJIT = true; - } - } - } - if (exists) - { - return enableXQAJIT; - } - else - { - return std::nullopt; - } -} - -// Tune the number of blocks per sequence for accuracy/performance purpose. 
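
Most of the getters in the removed `envUtils.cpp` follow the same pattern: read the environment variable once, require an exact `"1"` to enable the feature, and cache the result in a function-local static. A self-contained sketch of that pattern, using a hypothetical flag name:

```cpp
// Sketch of the cached boolean env-var pattern; "MY_FEATURE_FLAG" is hypothetical.
#include <cstdlib>
#include <iostream>
#include <string>

namespace
{
// Returns true iff the variable exists and is set to exactly "1".
bool getBoolEnv(char const* name)
{
    char const* env = std::getenv(name);
    return env != nullptr && std::string(env) == "1";
}
} // namespace

bool isMyFeatureEnabled()
{
    // Function-local static: the getenv lookup happens only on the first call.
    static bool const enabled = getBoolEnv("MY_FEATURE_FLAG");
    return enabled;
}

int main()
{
    std::cout << std::boolalpha << isMyFeatureEnabled() << std::endl;
    return 0;
}
```
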
-bool getEnvMmhaMultiblockDebug() -{ - static bool init = false; - static bool forceMmhaMaxSeqLenTile = false; - if (!init) - { - init = true; - char const* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG"); - if (enable_mmha_debug_var) - { - if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0') - { - forceMmhaMaxSeqLenTile = true; - } - } - } - return forceMmhaMaxSeqLenTile; -} - -int getEnvMmhaBlocksPerSequence() -{ - static bool init = false; - static int mmhaBlocksPerSequence = 0; - if (!init) - { - init = true; - char const* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE"); - if (mmhaBlocksPerSequenceEnv) - { - mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv); - if (mmhaBlocksPerSequence <= 0) - { - TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_BLOCKS_PER_SEQUENCE. Will use default values instead!"); - } - } - } - return mmhaBlocksPerSequence; -} - -int getEnvMmhaKernelBlockSize() -{ - static bool init = false; - static int mmhaKernelBlockSize = 0; - if (!init) - { - init = true; - char const* mmhaKernelBlockSizeEnv = std::getenv("TRTLLM_MMHA_KERNEL_BLOCK_SIZE"); - if (mmhaKernelBlockSizeEnv) - { - mmhaKernelBlockSize = std::atoi(mmhaKernelBlockSizeEnv); - if (mmhaKernelBlockSize <= 0) - { - TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_KERNEL_BLOCK_SIZE. Will use default values instead!"); - } - } - } - return mmhaKernelBlockSize; -} - -bool getEnvEnablePDL() -{ - static bool init = false; - static bool enablePDL = false; - if (!init) - { - init = true; - // PDL only available when arch >= 90 - if (getSMVersion() >= 90) - { - // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1` - enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL"); - } - } - return enablePDL; -} - -bool getEnvUseUCXKvCache() -{ - static bool const useUCXKVCache = getBoolEnv("TRTLLM_USE_UCX_KVCACHE"); - return useUCXKVCache; -} - -std::string getEnvUCXInterface() -{ - static bool init = false; - static std::string ucxInterface; - if (!init) - { - init = true; - { - char const* ucx_interface = std::getenv("TRTLLM_UCX_INTERFACE"); - if (ucx_interface) - { - ucxInterface = ucx_interface; - } - } - } - return ucxInterface; -} - -bool getEnvDisaggLayerwise() -{ - static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE"); - return disaggLayerwise; -} - -bool getEnvParallelCacheSend() -{ - static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND"); - return parallelCacheSend; -} - -bool getEnvRequestKVCacheSerial() -{ - static bool const requestKVCacheSerial = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_SERIAL"); - return requestKVCacheSerial; -} - -bool getEnvDisableKVCacheTransferOverlap() -{ - static bool const disableKVCacheTransferOverlap = getBoolEnv("TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP"); - return disableKVCacheTransferOverlap; -} - -bool getEnvDisableReceiveKVCacheParallel() -{ - static bool const disableReceiveParallel = getBoolEnv("TRTLLM_DISABLE_KVCACHE_RECEIVE_PARALLEL"); - return disableReceiveParallel; -} - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h deleted file mode 100644 index 027c7cfbb3b..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h +++ /dev/null @@ -1,60 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include - -namespace tensorrt_llm::common -{ -// Useful when you want to inject some debug code controllable with env var. -std::optional getIntEnv(char const* name); - -// XQA kernels (optimized kernels for generation phase). -bool forceXQAKernels(); - -// Whether XQA JIT is enabled. -// -// Returns the value of TRTLLM_ENABLE_XQA_JIT env var. If such env var doesn't exist, std::nullopt is returned. -std::optional getEnvEnableXQAJIT(); - -// Tune the number of blocks per sequence for accuracy/performance purpose. -bool getEnvMmhaMultiblockDebug(); - -int getEnvMmhaBlocksPerSequence(); - -int getEnvMmhaKernelBlockSize(); - -// Whether PDL is enabled. -bool getEnvEnablePDL(); - -bool getEnvUseUCXKvCache(); - -std::string getEnvUCXInterface(); - -bool getEnvDisaggLayerwise(); - -bool getEnvParallelCacheSend(); - -bool getEnvRequestKVCacheSerial(); - -bool getEnvDisableKVCacheTransferOverlap(); - -bool getEnvDisableReceiveKVCacheParallel(); - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/logger.h b/sgl-kernel/3rdparty/tensorrt_llm/common/logger.h new file mode 100644 index 00000000000..df84e226389 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/logger.h @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/stringUtils.h" + +namespace tensorrt_llm::common +{ + +class Logger +{ + +// On Windows, the file wingdi.h is included which has +// #define ERROR 0 +// This breaks everywhere ERROR is used in the Level enum +#ifdef _WIN32 +#undef ERROR +#endif // _WIN32 + +public: + enum Level + { + TRACE = 0, + DEBUG = 10, + INFO = 20, + WARNING = 30, + ERROR = 40 + }; + + static Logger* getLogger(); + + Logger(Logger const&) = delete; + void operator=(Logger const&) = delete; + +#if defined(_MSC_VER) + template + void log(Level level, char const* format, Args const&... args); + + template + void log(Level level, int rank, char const* format, Args const&... args); +#else + template + void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0))); + + template + void log(Level level, int rank, char const* format, Args const&... 
args) __attribute__((format(printf, 4, 0))); +#endif + + template + void log(Level level, std::string const& format, Args const&... args) + { + return log(level, format.c_str(), args...); + } + + template + void log(Level const level, int const rank, std::string const& format, Args const&... args) + { + return log(level, rank, format.c_str(), args...); + } + + void log(std::exception const& ex, Level level = Level::ERROR); + + Level getLevel() const + { + return level_; + } + + void setLevel(Level const level) + { + level_ = level; + log(INFO, "Set logger level to %s", getLevelName(level)); + } + + bool isEnabled(Level const level) const + { + return level_ <= level; + } + +private: + static auto constexpr kPREFIX = "[TensorRT-LLM]"; + +#ifndef NDEBUG + Level const DEFAULT_LOG_LEVEL = DEBUG; +#else + Level const DEFAULT_LOG_LEVEL = INFO; +#endif + Level level_ = DEFAULT_LOG_LEVEL; + + Logger(); // NOLINT(modernize-use-equals-delete) + + static inline char const* getLevelName(Level const level) + { + switch (level) + { + case TRACE: return "TRACE"; + case DEBUG: return "DEBUG"; + case INFO: return "INFO"; + case WARNING: return "WARNING"; + case ERROR: return "ERROR"; + } + + TLLM_THROW("Unknown log level: %d", level); + } + + static inline std::string getPrefix(Level const level) + { + return fmtstr("%s[%s] ", kPREFIX, getLevelName(level)); + } + + static inline std::string getPrefix(Level const level, int const rank) + { + return fmtstr("%s[%s][%d] ", kPREFIX, getLevelName(level), rank); + } +}; + +template +void Logger::log(Logger::Level level, char const* format, Args const&... args) +{ + if (isEnabled(level)) + { + auto const fmt = getPrefix(level) + format; + auto& out = level_ < WARNING ? std::cout : std::cerr; + if constexpr (sizeof...(args) > 0) + { + out << fmtstr(fmt.c_str(), args...); + } + else + { + out << fmt; + } + out << std::endl; + } +} + +template +void Logger::log(Logger::Level const level, int const rank, char const* format, Args const&... args) +{ + if (isEnabled(level)) + { + auto const fmt = getPrefix(level, rank) + format; + auto& out = level_ < WARNING ? std::cout : std::cerr; + if constexpr (sizeof...(args) > 0) + { + out << fmtstr(fmt.c_str(), args...); + } + else + { + out << fmt; + } + out << std::endl; + } +} + +#define TLLM_LOG(level, ...) \ + do \ + { \ + auto* const logger = tensorrt_llm::common::Logger::getLogger(); \ + if (logger->isEnabled(level)) \ + { \ + logger->log(level, __VA_ARGS__); \ + } \ + } while (0) + +#define TLLM_LOG_TRACE(...) TLLM_LOG(tensorrt_llm::common::Logger::TRACE, __VA_ARGS__) +#define TLLM_LOG_DEBUG(...) TLLM_LOG(tensorrt_llm::common::Logger::DEBUG, __VA_ARGS__) +#define TLLM_LOG_INFO(...) TLLM_LOG(tensorrt_llm::common::Logger::INFO, __VA_ARGS__) +#define TLLM_LOG_WARNING(...) TLLM_LOG(tensorrt_llm::common::Logger::WARNING, __VA_ARGS__) +#define TLLM_LOG_ERROR(...) TLLM_LOG(tensorrt_llm::common::Logger::ERROR, __VA_ARGS__) +#define TLLM_LOG_EXCEPTION(ex, ...) tensorrt_llm::common::Logger::getLogger()->log(ex, ##__VA_ARGS__) +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h deleted file mode 100644 index 1bad3a2c152..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace tensorrt_llm -{ -namespace common -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ __host__ T divUp(T m, T n) -{ - return (m + n - 1) / n; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu deleted file mode 100644 index d13217b203a..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu +++ /dev/null @@ -1,906 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/cudaTypeUtils.cuh" -#include "tensorrt_llm/common/logger.h" -#include "tensorrt_llm/common/memoryUtils.h" - -#include -#include -#include - -namespace tensorrt_llm -{ -namespace common -{ - -template -void deviceMalloc(T** ptr, size_t size, bool is_random_initialize) -{ - check_cuda_error(cudaMalloc((void**) (ptr), sizeof(T) * size)); - if (is_random_initialize) - { - cudaRandomUniform(*ptr, size); - } -} - -template void deviceMalloc(float** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(half** ptr, size_t size, bool is_random_initialize); -#ifdef ENABLE_BF16 -template void deviceMalloc(__nv_bfloat16** ptr, size_t size, bool is_random_initialize); -#endif -template void deviceMalloc(uint16_t** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(int** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(bool** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(char** ptr, size_t size, bool is_random_initialize); -template void deviceMalloc(int8_t** ptr, size_t size, bool is_random_initialize); -#ifdef ENABLE_FP8 -template void deviceMalloc(__nv_fp8_e4m3** ptr, size_t size, bool is_random_initialize); -#endif - -template -void deviceMemSetZero(T* ptr, size_t size) -{ - check_cuda_error(cudaMemset(static_cast(ptr), 0, sizeof(T) * size)); -} - -template void deviceMemSetZero(float* ptr, size_t size); -template void deviceMemSetZero(half* ptr, size_t size); -template void deviceMemSetZero(int* ptr, size_t size); -template void deviceMemSetZero(uint32_t* ptr, size_t size); -template void deviceMemSetZero(bool* ptr, size_t size); -#ifdef ENABLE_FP8 -template void 
deviceMemSetZero(__nv_fp8_e4m3* ptr, size_t size); -#endif -#ifdef ENABLE_BF16 -template void deviceMemSetZero(__nv_bfloat16* ptr, size_t size); -#endif - -template -void deviceFree(T*& ptr) -{ - if (ptr != NULL) - { - check_cuda_error(cudaFree(ptr)); - ptr = NULL; - } -} - -template void deviceFree(float*& ptr); -template void deviceFree(half*& ptr); -#ifdef ENABLE_BF16 -template void deviceFree(__nv_bfloat16*& ptr); -#endif -template void deviceFree(unsigned short*& ptr); -template void deviceFree(int*& ptr); -template void deviceFree(bool*& ptr); -template void deviceFree(char*& ptr); -template void deviceFree(int8_t*& ptr); -#ifdef ENABLE_FP8 -template void deviceFree(__nv_fp8_e4m3*& ptr); -#endif - -template -void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream) -{ - T* arr = new T[size]; - std::fill(arr, arr + size, value); - check_cuda_error(cudaMemcpyAsync(devptr, arr, sizeof(T) * size, cudaMemcpyHostToDevice, stream)); - delete[] arr; -} - -template void deviceFill(float* devptr, size_t size, float value, cudaStream_t stream); -template void deviceFill(half* devptr, size_t size, half value, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void deviceFill(__nv_bfloat16* devptr, size_t size, __nv_bfloat16 value, cudaStream_t stream); -#endif -template void deviceFill(int* devptr, size_t size, int value, cudaStream_t stream); -template void deviceFill(bool* devptr, size_t size, bool value, cudaStream_t stream); - -template -void cudaD2Hcpy(T* tgt, T const* src, const size_t size) -{ - check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToHost)); -} - -template void cudaD2Hcpy(float* tgt, float const* src, size_t size); -template void cudaD2Hcpy(half* tgt, half const* src, size_t size); -#ifdef ENABLE_BF16 -template void cudaD2Hcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size); -#endif -template void cudaD2Hcpy(int* tgt, int const* src, size_t size); -template void cudaD2Hcpy(bool* tgt, bool const* src, size_t size); -#ifdef ENABLE_FP8 -template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size); -#endif -template void cudaD2Hcpy(unsigned long long* tgt, unsigned long long const* src, size_t size); -template void cudaD2Hcpy(unsigned int* tgt, unsigned int const* src, size_t size); -template void cudaD2Hcpy(int8_t* tgt, int8_t const* src, size_t size); - -template -void cudaH2Dcpy(T* tgt, T const* src, const size_t size) -{ - check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyHostToDevice)); -} - -template void cudaH2Dcpy(float* tgt, float const* src, size_t size); -template void cudaH2Dcpy(half* tgt, half const* src, size_t size); -#ifdef ENABLE_BF16 -template void cudaH2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size); -#endif -template void cudaH2Dcpy(int* tgt, int const* src, size_t size); -template void cudaH2Dcpy(bool* tgt, bool const* src, size_t size); -#ifdef ENABLE_FP8 -template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size); -#endif -template void cudaH2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size); -template void cudaH2Dcpy(unsigned int* tgt, unsigned int const* src, size_t size); -template void cudaH2Dcpy(int8_t* tgt, int8_t const* src, size_t size); - -template -void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream) -{ - check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToDevice, stream)); -} - -template void cudaD2Dcpy(float* tgt, float const* src, size_t size, 
cudaStream_t stream); -template void cudaD2Dcpy(half* tgt, half const* src, size_t size, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void cudaD2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream); -#endif -template void cudaD2Dcpy(int* tgt, int const* src, size_t size, cudaStream_t stream); -template void cudaD2Dcpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream); -template void cudaD2Dcpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream); -#ifdef ENABLE_FP8 -template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size, cudaStream_t stream); -#endif -template void cudaD2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream); - -template -__global__ void cudaCast(T_OUT* dst, T_IN* src, const size_t size) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = (T_OUT) ((float) (src[tid])); - } -} - -template -void invokeCudaCast(T_OUT* dst, T_IN const* const src, const size_t size, cudaStream_t stream) -{ - cudaCast<<<256, 256, 0, stream>>>(dst, src, size); -} - -template void invokeCudaCast(float* dst, half const* const src, const size_t size, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void invokeCudaCast(float* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast(__nv_bfloat16* dst, float const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast(__nv_bfloat16* dst, half const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast(half* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); -#endif -#ifdef ENABLE_FP8 -template void invokeCudaCast(float* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast( - __nv_bfloat16* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast(half* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast(__nv_fp8_e4m3* dst, float const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast( - __nv_fp8_e4m3* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); -template void invokeCudaCast(__nv_fp8_e4m3* dst, half const* const src, const size_t size, cudaStream_t stream); -#endif - -template -void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream) -{ - if (stream != NULL) - { - check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDefault, stream)); - } - else - { - check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDefault)); - } -} - -template void cudaAutoCpy(float* tgt, float const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(half* tgt, half const* src, size_t size, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void cudaAutoCpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream); -#endif -template void cudaAutoCpy(int* tgt, int const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(uint8_t* tgt, uint8_t const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(uint32_t* tgt, 
uint32_t const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(unsigned long* tgt, unsigned long const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(char* tgt, char const* src, size_t size, cudaStream_t stream); - -template void cudaAutoCpy(float const** tgt, float const* const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(half const** tgt, half const* const* src, size_t size, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void cudaAutoCpy(__nv_bfloat16 const** tgt, __nv_bfloat16 const* const* src, size_t size, cudaStream_t stream); -#endif -template void cudaAutoCpy(int const** tgt, int const* const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(bool const** tgt, bool const* const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy(int8_t const** tgt, int8_t const* const* src, size_t size, cudaStream_t stream); -template void cudaAutoCpy( - unsigned long long const** tgt, unsigned long long const* const* src, size_t size, cudaStream_t stream); - -template -__global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, int const seq_offset) -{ - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - curandState_t local_state; - curand_init((unsigned long long int) 1337, idx + seq_offset, 0, &local_state); - for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) - { - buffer[index] = (T) (curand_uniform(&local_state) * 0.2f - 0.1f); - } -} - -template <> -__global__ void cuda_random_uniform_kernel(int* buffer, const size_t size, int const seq_offset) -{ - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - curandState_t local_state; - curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); - for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) - { - buffer[index] = curand(&local_state); - } -} - -template <> -__global__ void cuda_random_uniform_kernel(bool* buffer, const size_t size, int const seq_offset) -{ - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - curandState_t local_state; - curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); - for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) - { - buffer[index] = (curand(&local_state) % 2 == 0); - } -} - -template <> -__global__ void cuda_random_uniform_kernel(char* buffer, const size_t size, int const seq_offset) -{ - const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; - curandState_t local_state; - curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); - for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) - { - buffer[index] = curand(&local_state) % 0xFF; - } -} - -template -void cudaRandomUniform(T* buffer, const size_t size) -{ - static int seq_offset = 0; - cuda_random_uniform_kernel<<<256, 256>>>(buffer, size, seq_offset); - seq_offset += 256 * 256; -} - -template void cudaRandomUniform(float* buffer, const size_t size); -template void cudaRandomUniform(half* buffer, const size_t size); -#ifdef ENABLE_BF16 -template void cudaRandomUniform(__nv_bfloat16* buffer, const size_t size); -#endif -template void cudaRandomUniform(int* buffer, const size_t size); -template void cudaRandomUniform(bool* buffer, const size_t size); -template void cudaRandomUniform(char* buffer, const size_t size); -#ifdef ENABLE_FP8 -template void cudaRandomUniform(__nv_fp8_e4m3* buffer, const size_t 
size); -#endif - -// loads data from binary file. If it succeeds, returns a non-empty vector. If loading fails or -// the product of the elements in shape is 0, this function will return an empty vector. -template -std::vector loadWeightFromBinHelper(std::vector shape, std::string filename) -{ - if (shape.size() > 2) - { - printf("[ERROR] shape should have less than two dims \n"); - return std::vector(); - } - size_t dim0 = shape[0], dim1 = 1; - if (shape.size() == 2) - { - dim1 = shape[1]; - } - size_t size = dim0 * dim1; - if (size == 0) - { - TLLM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str()); - return std::vector(); - } - - std::vector host_array(size); - std::ifstream in(filename, std::ios::in | std::ios::binary); - if (!in.is_open()) - { - TLLM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str()); - return std::vector(); - } - - size_t loaded_data_size = sizeof(T) * size; - in.seekg(0, in.end); - in.seekg(0, in.beg); - - TLLM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename); - in.read((char*) host_array.data(), loaded_data_size); - - size_t in_get_size = in.gcount(); - if (in_get_size != loaded_data_size) - { - TLLM_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! \n", filename.c_str(), - in_get_size, loaded_data_size); - return std::vector(); - } - in.close(); - // If we succeed, return an array with values. - return host_array; -} - -template -int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filename) -{ - std::vector host_array = loadWeightFromBinHelper(shape, filename); - - if (host_array.empty()) - { - return 0; - } - - if (std::is_same::value == true) - { - cudaH2Dcpy(ptr, (T*) host_array.data(), host_array.size()); - } - else - { - T_IN* ptr_2 = nullptr; - deviceMalloc(&ptr_2, host_array.size(), false); - cudaH2Dcpy(ptr_2, host_array.data(), host_array.size()); - invokeCudaD2DcpyConvert(ptr, ptr_2, host_array.size()); - deviceFree(ptr_2); - } - return 0; -} - -template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(int8_t* ptr, std::vector shape, std::string filename); -#ifdef ENABLE_BF16 -template int loadWeightFromBinFunc<__nv_bfloat16, float>( - __nv_bfloat16* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc<__nv_bfloat16, half>( - __nv_bfloat16* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); -template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>( - __nv_bfloat16* ptr, std::vector shape, std::string filename); -#endif // ENABLE_BF16 -template int loadWeightFromBinFunc(int* ptr, std::vector shape, std::string filename); -#ifdef ENABLE_FP8 -template int loadWeightFromBinFunc<__nv_fp8_e4m3, float>( - __nv_fp8_e4m3* ptr, std::vector shape, std::string filename); -#endif // ENABLE_FP8 - -template -int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type) -{ - switch (model_file_type) - { - case TRTLLMCudaDataType::FP32: 
loadWeightFromBinFunc(ptr, shape, filename); break; - case TRTLLMCudaDataType::FP16: loadWeightFromBinFunc(ptr, shape, filename); break; - case TRTLLMCudaDataType::INT8: loadWeightFromBinFunc(ptr, shape, filename); break; -#ifdef ENABLE_BF16 - case TRTLLMCudaDataType::BF16: loadWeightFromBinFunc(ptr, shape, filename); break; -#endif -#ifdef ENABLE_FP8 - case TRTLLMCudaDataType::FP8: loadWeightFromBinFunc(ptr, shape, filename); break; -#endif - default: TLLM_LOG_ERROR("Does not support TRTLLMCudaDataType=%d", model_file_type); TLLM_CHECK(false); - } - return 0; -} - -template <> -int loadWeightFromBin(int* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type) -{ - loadWeightFromBinFunc(ptr, shape, filename); - return 0; -} - -template int loadWeightFromBin( - float* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); -template int loadWeightFromBin( - half* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); -template int loadWeightFromBin( - int8_t* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); -#ifdef ENABLE_BF16 -template int loadWeightFromBin( - __nv_bfloat16* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); -#endif -#ifdef ENABLE_FP8 -template int loadWeightFromBin( - __nv_fp8_e4m3* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); -#endif -template int loadWeightFromBin( - int* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); - -template -__global__ void cudaD2DcpyConvert(T_OUT* dst, const T_IN* src, const size_t size) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = cuda_cast(src[tid]); - } -} - -template -void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cudaStream_t stream) -{ - cudaD2DcpyConvert<<<256, 256, 0, stream>>>(tgt, src, size); -} - -template void invokeCudaD2DcpyConvert(int8_t* tgt, float const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, int8_t const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, int const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(half* tgt, int const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, float const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(half* tgt, float const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, half const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(uint32_t* tgt, int const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(int* tgt, uint32_t const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(int* tgt, float const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(int* tgt, half const* src, const size_t size, cudaStream_t stream); - -#ifdef ENABLE_BF16 -template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, float const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, int const* src, const size_t size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(float* tgt, __nv_bfloat16 const* src, const size_t 
size, cudaStream_t stream); -template void invokeCudaD2DcpyConvert(int* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream); -#endif // ENABLE_BF16 - -template -__global__ void cudaD2DScaleCpyConvert( - T_OUT* dst, const T_IN* src, float const* scale, bool invert_scale, const size_t size) -{ - float const scale_value = invert_scale ? 1.0f / scale[0] : scale[0]; - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = cuda_cast(cuda_cast(src[tid]) * scale_value); - } -} - -template -void invokeCudaD2DScaleCpyConvert( - T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream) -{ - cudaD2DScaleCpyConvert<<<256, 256, 0, stream>>>(tgt, src, scale, invert_scale, size); -} - -// clang-format off -template void invokeCudaD2DScaleCpyConvert(float* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); -template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const float* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); -template void invokeCudaD2DScaleCpyConvert(half* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); -template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const half* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void invokeCudaD2DScaleCpyConvert(__nv_bfloat16* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); -template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const __nv_bfloat16* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); -#endif // ENABLE_BF16 -#ifdef ENABLE_FP8 -template void invokeCudaD2DScaleCpyConvert(float* tgt, const __nv_fp8_e4m3* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); -#endif // ENABLE_FP8 -// clang-format on - -void invokeCudaD2DcpyHalf2Float(float* dst, half* src, const size_t size, cudaStream_t stream) -{ - invokeCudaD2DcpyConvert(dst, src, size, stream); -} - -void invokeCudaD2DcpyFloat2Half(half* dst, float* src, const size_t size, cudaStream_t stream) -{ - invokeCudaD2DcpyConvert(dst, src, size, stream); -} - -template -void saveToBinary(T const* ptr, const size_t size, std::string filename) -{ - - std::vector h_ptr(size); - cudaD2Hcpy(h_ptr.data(), ptr, size); - std::vector float_ptr(size); - for (size_t i = 0; i < size; i++) - { - float_ptr[i] = (float) h_ptr[i]; - } - - std::ofstream out(filename, std::ios::out | std::ios::binary); - TLLM_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename); - - out.write((char*) float_ptr.data(), size * sizeof(float)); -} - -template void saveToBinary(float const* ptr, const size_t size, std::string filename); -template void saveToBinary(half const* ptr, const size_t size, std::string filename); -#ifdef ENABLE_BF16 -template void saveToBinary(__nv_bfloat16 const* ptr, const size_t size, std::string filename); -#endif // ENABLE_BF16 - -template <> -void saveToBinary(int const* ptr, const size_t size, std::string filename) -{ - std::vector h_ptr(size); - cudaD2Hcpy(h_ptr.data(), ptr, size); - std::ofstream out(filename, std::ios::out | std::ios::binary); - TLLM_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename); - out.write((char*) h_ptr.data(), size * sizeof(int)); -} - -template -__global__ void fakeCast(T_IN* input_ptr, 
const size_t size) -{ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) - { - T_fake_type tmp_val = (T_fake_type) ((float) input_ptr[i]); - input_ptr[i] = (T_IN) ((float) tmp_val); - } -} - -template -void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream) -{ - dim3 block(256); - dim3 grid((size + 255) / 256); - fakeCast<<>>(input_ptr, size); -} - -#ifdef ENABLE_FP8 -__global__ void cudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, const size_t size) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = (float) (src[tid]); - } -} - -void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, const size_t size, cudaStream_t stream) -{ - cudaD2Dcpyfp82Float<<<256, 256, 0, stream>>>(dst, src, size); -} - -__global__ void cudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, const size_t size) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = (half) ((float) (src[tid])); - } -} - -void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, const size_t size, cudaStream_t stream) -{ - cudaD2Dcpyfp82Half<<<256, 256, 0, stream>>>(dst, src, size); -} - -__global__ void cudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, const size_t size) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = (__nv_fp8_e4m3) src[tid]; - } -} - -void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, const size_t size, cudaStream_t stream) -{ - cudaD2DcpyFloat2fp8<<<256, 256, 0, stream>>>(dst, src, size); -} - -__global__ void cudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, const size_t size) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = (__nv_fp8_e4m3) src[tid]; - } -} - -void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, const size_t size, cudaStream_t stream) -{ - cudaD2DcpyHalf2fp8<<<256, 256, 0, stream>>>(dst, src, size); -} - -__global__ void cudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, const size_t size) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) - { - dst[tid] = (__nv_fp8_e4m3) src[tid]; - } -} - -void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, const size_t size, cudaStream_t stream) -{ - cudaD2DcpyBfloat2fp8<<<256, 256, 0, stream>>>(dst, src, size); -} - -#endif // ENABLE_FP8 - -template -__global__ void transpose(T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1; tid += blockDim.x * gridDim.x) - { - const size_t src_col_id = tid % dim1; - const size_t src_row_id = tid / dim1; - dst[src_col_id * dim0 + src_row_id] = (T_OUT) (src[tid]); - } -} - -template -void invokeInPlaceTranspose(T* data, T* workspace, const size_t dim0, const size_t dim1) -{ - // copy data to workspace, and then transpose from workspace to data - cudaD2Dcpy(workspace, data, dim0 * dim1); - transpose<<<256, 256>>>(data, workspace, dim0, dim1); -} - -#ifdef ENABLE_FP8 -template void invokeInPlaceTranspose( - __nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, const size_t dim1); -#endif // ENABLE_FP8 -#ifdef ENABLE_BF16 -template void invokeInPlaceTranspose( - __nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, const size_t dim1); -#endif // ENABLE_BF16 
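
The `transpose` kernel above (used by `invokeInPlaceTranspose` through a scratch workspace) walks the flattened index space with a grid-stride loop and writes element `(row, col)` of the source to `(col, row)` of the destination. A minimal standalone CUDA sketch of the same indexing:

```cpp
// Grid-stride transpose sketch mirroring the indexing of the kernel above.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

__global__ void transposeSketch(float* dst, float const* src, size_t dim0, size_t dim1)
{
    for (size_t tid = threadIdx.x + blockIdx.x * (size_t) blockDim.x; tid < dim0 * dim1;
         tid += (size_t) blockDim.x * gridDim.x)
    {
        size_t const col = tid % dim1; // column in the dim0 x dim1 source
        size_t const row = tid / dim1; // row in the source
        dst[col * dim0 + row] = src[tid];
    }
}

int main()
{
    size_t const dim0 = 3, dim1 = 5;
    std::vector<float> h(dim0 * dim1);
    for (size_t i = 0; i < h.size(); ++i) h[i] = (float) i;

    float *dSrc = nullptr, *dDst = nullptr;
    cudaMalloc(&dSrc, h.size() * sizeof(float));
    cudaMalloc(&dDst, h.size() * sizeof(float));
    cudaMemcpy(dSrc, h.data(), h.size() * sizeof(float), cudaMemcpyHostToDevice);

    transposeSketch<<<256, 256>>>(dDst, dSrc, dim0, dim1);

    std::vector<float> out(h.size());
    cudaMemcpy(out.data(), dDst, out.size() * sizeof(float), cudaMemcpyDeviceToHost);

    // Element (0, 1) of the 3x5 input now sits at (1, 0) of the 5x3 output.
    std::printf("out[1 * %zu + 0] = %g (was src[0 * %zu + 1] = %g)\n", dim0, out[1 * dim0 + 0], dim1, h[1]);

    cudaFree(dSrc);
    cudaFree(dDst);
    return 0;
}
```
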
-template void invokeInPlaceTranspose(float* data, float* workspace, const size_t dim0, const size_t dim1); - -template -__global__ void transpose0213( - T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3) -{ - // src permutation: [0, 1, 2, 3] - // dst permutation: [0, 2, 1, 3] - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2 * dim3; - tid += blockDim.x * gridDim.x) - { - size_t tmp_idx = tid; - const size_t dim_3_idx = tmp_idx % dim3; - tmp_idx = (tmp_idx - dim_3_idx) / dim3; - const size_t dim_2_idx = tmp_idx % dim2; - tmp_idx = (tmp_idx - dim_2_idx) / dim2; - const size_t dim_1_idx = tmp_idx % dim1; - tmp_idx = (tmp_idx - dim_1_idx) / dim1; - const size_t dim_0_idx = tmp_idx % dim0; - dst[dim_0_idx * dim1 * dim2 * dim3 + dim_2_idx * dim1 * dim3 + dim_1_idx * dim3 + dim_3_idx] = src[tid]; - } -} - -template -void invokeInPlaceTranspose0213( - T* data, T* workspace, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3) -{ - // copy data to workspace, and then transpose from workspace to data - // Note that this kernel is used for pre-processing and not very efficient. - cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2 * dim3); - transpose0213<<<256, 256>>>(data, workspace, dim0, dim1, dim2, dim3); -} - -#ifdef ENABLE_FP8 -template void invokeInPlaceTranspose0213(__nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, - const size_t dim1, const size_t dim2, const size_t dim3); -#endif // ENABLE_FP8 -#ifdef ENABLE_BF16 -template void invokeInPlaceTranspose0213(__nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, - const size_t dim1, const size_t dim2, const size_t dim3); -#endif // ENABLE_BF16 -template void invokeInPlaceTranspose0213( - float* data, float* workspace, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3); - -template -__global__ void transpose102(T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1, const size_t dim2) -{ - // src permutation: [0, 1, 2] - // dst permutation: [1, 0, 2] - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2; tid += blockDim.x * gridDim.x) - { - size_t tmp_idx = tid; - const size_t dim_2_idx = tmp_idx % dim2; - tmp_idx = (tmp_idx - dim_2_idx) / dim2; - const size_t dim_1_idx = tmp_idx % dim1; - tmp_idx = (tmp_idx - dim_1_idx) / dim1; - const size_t dim_0_idx = tmp_idx % dim0; - dst[dim_1_idx * dim0 * dim2 + dim_0_idx * dim2 + dim_2_idx] = src[tid]; - } -} - -template -void invokeInPlaceTranspose102(T* data, T* workspace, const size_t dim0, const size_t dim1, const size_t dim2) -{ - // copy data to workspace, and then transpose from workspace to data - // Note that this kernel is used for pre-processing and not very efficient. 
- cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2); - transpose102<<<256, 256>>>(data, workspace, dim0, dim1, dim2); -} - -#ifdef ENABLE_FP8 -template void invokeInPlaceTranspose102( - __nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, const size_t dim1, const size_t dim2); -#endif // ENABLE_FP8 -#ifdef ENABLE_BF16 -template void invokeInPlaceTranspose102( - __nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, const size_t dim1, const size_t dim2); -#endif // ENABLE_BF16 -template void invokeInPlaceTranspose102( - float* data, float* workspace, const size_t dim0, const size_t dim1, const size_t dim2); - -template -void __global__ multiplyScale(T* tensor, float scale, const size_t size) -{ - for (size_t index = threadIdx.x + blockIdx.x * blockDim.x; index < size; index += blockDim.x * gridDim.x) - { - tensor[index] = (T) (((float) tensor[index]) * scale); - } -} - -template -void invokeMultiplyScale(T* tensor, float scale, const size_t size, cudaStream_t stream) -{ - int block = 256; - int grid = (size + 255) / 256; - multiplyScale<<>>(tensor, scale, size); -} - -template void invokeMultiplyScale(float* tensor, float scale, const size_t size, cudaStream_t stream); -template void invokeMultiplyScale(half* tensor, float scale, const size_t size, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void invokeMultiplyScale(__nv_bfloat16* tensor, float scale, const size_t size, cudaStream_t stream); -#endif -#ifdef ENABLE_FP8 -template void invokeMultiplyScale(__nv_fp8_e4m3* tensor, float scale, const size_t size, cudaStream_t stream); -#endif - -template -void __global__ divideScale(T* tensor, float scale, const size_t size) -{ - for (size_t index = threadIdx.x + blockIdx.x * blockDim.x; index < size; index += blockDim.x * gridDim.x) - { - tensor[index] = (T) (((float) tensor[index]) / scale); - } -} - -template -void invokeDivideScale(T* tensor, float scale, const size_t size, cudaStream_t stream) -{ - int block = 256; - int grid = (size + 255) / 256; - divideScale<<>>(tensor, scale, size); -} - -template void invokeDivideScale(float* tensor, float scale, const size_t size, cudaStream_t stream); -template void invokeDivideScale(half* tensor, float scale, const size_t size, cudaStream_t stream); -#ifdef ENABLE_BF16 -template void invokeDivideScale(__nv_bfloat16* tensor, float scale, const size_t size, cudaStream_t stream); -#endif -#ifdef ENABLE_FP8 -template void invokeDivideScale(__nv_fp8_e4m3* tensor, float scale, const size_t size, cudaStream_t stream); -#endif -#ifdef ENABLE_BF16 -template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); -template void invokeFakeCast<__nv_bfloat16, __nv_bfloat16>( - __nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream); -template void invokeFakeCast(half* input_ptr, const size_t size, cudaStream_t stream); -#endif -template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); -template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); -#ifdef ENABLE_FP8 -template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); -template void invokeFakeCast(half* input_ptr, const size_t size, cudaStream_t stream); -template void invokeFakeCast<__nv_bfloat16, __nv_fp8_e4m3>( - __nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream); -#endif - -size_t cuda_datatype_size(TRTLLMCudaDataType dt) -{ - static const std::unordered_map sizes{ - {TRTLLMCudaDataType::FP32, sizeof(float)}, {TRTLLMCudaDataType::FP16, 
sizeof(half)} -#ifdef ENABLE_BF16 - , - {TRTLLMCudaDataType::BF16, sizeof(__nv_bfloat16)} -#endif - }; - - return sizes.at(dt); -} - -template -__global__ void check_range(T const* buffer, size_t size, T min, T max, bool* d_within_range) -{ - for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) - { - const T val = buffer[i]; - if (val < min || val > max) - { - *d_within_range = false; - } - } -} - -template -bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream) -{ - cudaMemsetAsync(d_within_range, true, sizeof(bool), stream); - - dim3 block(256); - dim3 grid((size + 255) / 256); - check_range<<>>(buffer, size, min, max, d_within_range); - - bool result; - cudaD2Hcpy(&result, d_within_range, 1); - return result; -} - -template bool invokeCheckRange( - int const* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream); - -/* - * Determine the total workspace size based on a vector containing multiple variable sizes. - */ -size_t calcAlignedSize(std::vector const& sizes, const size_t ALIGN_BYTES) -{ - const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); - // Check ALIGN_BYTES is a power of 2 - assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0); - - size_t total = 0; - for (auto sz : sizes) - { - total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; - } - - // We add extra "ALIGN_BYTES - 1" bytes in case the start address passed to the function calcAlignedPointers() is - // not aligned. - return total + ALIGN_BYTES - 1; -} - -/* - * Given the address of the workspace and the vector containing multiple variable sizes, calculate the start addresses - * of each variable. - */ -void calcAlignedPointers( - std::vector& outPtrs, void const* p, std::vector const& sizes, size_t ALIGN_BYTES) -{ - const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); - // Check ALIGN_BYTES is a power of 2 - assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0); - - // In case the start address is not aligned - char* ptr = reinterpret_cast((reinterpret_cast(p) + ALIGN_BYTES - 1) & ALIGN_MASK); - - outPtrs.reserve(sizes.size()); - for (auto sz : sizes) - { - outPtrs.push_back(ptr); - ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; - } -} - -} // namespace common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h deleted file mode 100644 index 9e413a1beb8..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h +++ /dev/null @@ -1,292 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "tensorrt_llm/common/cudaFp8Utils.h" -#include "tensorrt_llm/common/cudaUtils.h" - -#include - -namespace tensorrt_llm -{ -namespace common -{ - -template -void deviceMalloc(T** ptr, size_t size, bool is_random_initialize = true); - -template -void deviceMemSetZero(T* ptr, size_t size); - -template - -void deviceFree(T*& ptr); - -template -void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0); - -template -void cudaD2Hcpy(T* tgt, T const* src, size_t const size); - -template -void cudaH2Dcpy(T* tgt, T const* src, size_t const size); - -template -void cudaD2Dcpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL); - -template -void cudaAutoCpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL); - -template -void cudaRandomUniform(T* buffer, size_t const size); - -template -int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, - TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32); - -// template -// int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr, -// T* scale_ptr, -// std::vector shape, -// std::string filename, -// TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32); - -void invokeCudaD2DcpyHalf2Float(float* dst, half* src, size_t const size, cudaStream_t stream); -void invokeCudaD2DcpyFloat2Half(half* dst, float* src, size_t const size, cudaStream_t stream); -#ifdef ENABLE_FP8 -void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream); -void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream); -void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, size_t const size, cudaStream_t stream); -void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, size_t const size, cudaStream_t stream); -void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream); -#endif // ENABLE_FP8 -#ifdef ENABLE_BF16 -void invokeCudaD2DcpyBfloat2Float(float* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream); -#endif // ENABLE_BF16 - -template -void invokeCudaCast(T_OUT* dst, T_IN const* const src, size_t const size, cudaStream_t stream); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// The following functions implement conversion of multi-dimensional indices to an index in a flat array. -// The shape of the Tensor dimensions is passed as one array (`dims`), the indices are given as individual arguments. -// For examples on how to use these functions, see their tests `test_memory_utils.cu`. -// All of these functions can be evaluated at compile time by recursive template expansion. - -template -__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( - T const& acc, TDim dims, TIndex const& index) -{ - assert(index < dims[0]); - return acc * dims[0] + index; -} - -template -__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( - T const& acc, TDim dims, TIndex const& index, TIndices... 
indices) -{ - assert(index < dims[0]); - return flat_index(acc * dims[0] + index, dims + 1, indices...); -} - -template -__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( - [[maybe_unused]] TDim dims, T const& index) -{ - assert(index < dims[0]); - return index; -} - -template -__inline__ __host__ __device__ - std::enable_if_t::value, typename std::remove_pointer::type> constexpr flat_index( - TDim dims, TIndex const& index, TIndices... indices) -{ - assert(index < dims[0]); - return flat_index(static_cast::type>(index), dims + 1, indices...); -} - -template -__inline__ __host__ __device__ T constexpr flat_index( - std::array const& dims, TIndex const& index, TIndices... indices) -{ - static_assert(skip < N); - static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); - return flat_index(&dims[skip], index, indices...); -} - -template -__inline__ __host__ __device__ T constexpr flat_index( - T const& acc, std::array const& dims, TIndex const& index, TIndices... indices) -{ - static_assert(skip < N); - static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); - return flat_index(acc, &dims[skip], index, indices...); -} - -template -__inline__ __host__ __device__ T constexpr flat_index(T const (&dims)[N], TIndex const& index, TIndices... indices) -{ - static_assert(skip < N); - static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); - return flat_index(static_cast(dims) + skip, index, indices...); -} - -template -__inline__ __host__ __device__ T constexpr flat_index( - T const& acc, T const (&dims)[N], TIndex const& index, TIndices... indices) -{ - static_assert(skip < N); - static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); - return flat_index(acc, static_cast(dims) + skip, index, indices...); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// These are simpler functions for multi-dimensional index conversion. Indices and dimensions are passed as individual -// arguments. These functions are more suitable for usage inside kernels than the corresponding flat_index functions -// which require arrays as arguments. Usage examples can be found in `test_memory_utils.cu`. The functions can be -// evaluated at compile time. 
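// Illustrative sketch (not part of the patch): the simpler flat_indexN helpers described in
// the comment above reduce to ordinary row-major index arithmetic. This standalone version
// mirrors flat_index2 / flat_index3 so the mapping can be checked at compile time; the
// function names are hypothetical stand-ins for the removed templates.
#include <cstddef>

constexpr std::size_t flatIndex2(std::size_t i0, std::size_t i1, std::size_t dim1)
{
    return i0 * dim1 + i1; // element [i0, i1] of a [dim0, dim1] tensor
}

constexpr std::size_t flatIndex3(
    std::size_t i0, std::size_t i1, std::size_t i2, std::size_t dim1, std::size_t dim2)
{
    return flatIndex2(i0, i1, dim1) * dim2 + i2; // element [i0, i1, i2] of a [dim0, dim1, dim2] tensor
}

static_assert(flatIndex2(2, 3, 5) == 13, "row-major: 2 * 5 + 3");
static_assert(flatIndex3(1, 2, 3, 4, 5) == 33, "(1 * 4 + 2) * 5 + 3");

int main()
{
    return 0;
}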
- -template -__inline__ __host__ __device__ T constexpr flat_index2(TIndex const& index_0, TIndex const& index_1, T const& dim_1) -{ - assert(index_1 < dim_1); - return index_0 * dim_1 + index_1; -} - -template -__inline__ __host__ __device__ T constexpr flat_index3( - TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& dim_1, T const& dim_2) -{ - assert(index_2 < dim_2); - return flat_index2(index_0, index_1, dim_1) * dim_2 + index_2; -} - -template -__inline__ __host__ __device__ T constexpr flat_index4(TIndex const& index_0, TIndex const& index_1, - TIndex const& index_2, TIndex const& index_3, T const& dim_1, T const& dim_2, T const& dim_3) -{ - assert(index_3 < dim_3); - return flat_index3(index_0, index_1, index_2, dim_1, dim_2) * dim_3 + index_3; -} - -template -__inline__ __host__ __device__ T constexpr flat_index5(TIndex const& index_0, TIndex const& index_1, - TIndex const& index_2, TIndex const& index_3, TIndex const& index_4, T const& dim_1, T const& dim_2, T const& dim_3, - T const& dim_4) -{ - assert(index_4 < dim_4); - return flat_index4(index_0, index_1, index_2, index_3, dim_1, dim_2, dim_3) * dim_4 + index_4; -} - -template -__inline__ __host__ __device__ T constexpr flat_index_strided3( - TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& stride_1, T const& stride_2) -{ - assert(index_1 < stride_1 / stride_2); - assert(index_2 < stride_2); - return index_0 * stride_1 + index_1 * stride_2 + index_2; -} - -template -__inline__ __host__ __device__ T constexpr flat_index_strided4(TIndex const& index_0, TIndex const& index_1, - TIndex const& index_2, TIndex const& index_3, T const& stride_1, T const& stride_2, T const& stride_3) -{ - assert(index_1 < stride_1 / stride_2); - assert(index_2 < stride_2 / stride_3); - assert(index_3 < stride_3); - return index_0 * stride_1 + index_1 * stride_2 + index_2 * stride_3 + index_3; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -void invokeInPlaceTranspose(T* data, T* workspace, size_t const dim0, size_t const dim1); - -template -void invokeInPlaceTranspose0213( - T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2, size_t const dim3); - -template -void invokeInPlaceTranspose102(T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2); - -template -void invokeMultiplyScale(T* tensor, float scale, size_t const size, cudaStream_t stream); - -template -void invokeDivideScale(T* tensor, float scale, size_t const size, cudaStream_t stream); - -template -void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, size_t const size, cudaStream_t stream = 0); - -template -void invokeCudaD2DScaleCpyConvert( - T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, size_t const size, cudaStream_t stream = 0); - -inline bool checkIfFileExist(std::string const& file_path) -{ - std::ifstream in(file_path, std::ios::in | std::ios::binary); - if (in.is_open()) - { - in.close(); - return true; - } - return false; -} - -template -void saveToBinary(T const* ptr, size_t const size, std::string filename); - -template -void invokeFakeCast(T_IN* input_ptr, size_t const size, cudaStream_t stream); - -size_t cuda_datatype_size(TRTLLMCudaDataType dt); - -template -bool invokeCheckRange(T const* buffer, size_t const size, T min, T max, bool* d_within_range, cudaStream_t stream); - -constexpr size_t DEFAULT_ALIGN_BYTES = 256; - -size_t calcAlignedSize(std::vector const& sizes, 
size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES); -void calcAlignedPointers(std::vector& outPtrs, void const* p, std::vector const& sizes, - size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES); - -struct AlignedPointersUnpacker -{ - template - void operator()(T*&... outPtrs) - { - assert(sizeof...(T) == alignedPointers.size()); - auto it = alignedPointers.begin(); - ((outPtrs = static_cast(*it++)), ...); - } - - std::vector alignedPointers; -}; - -AlignedPointersUnpacker inline calcAlignedPointers( - void const* p, std::vector const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES) -{ - AlignedPointersUnpacker unpacker{}; - calcAlignedPointers(unpacker.alignedPointers, p, sizes, ALIGN_BYTES); - return unpacker; -} - -} // namespace common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp deleted file mode 100644 index dbdaca4ee77..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp +++ /dev/null @@ -1,588 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "tensorrt_llm/common/mpiUtils.h" - -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/logger.h" -#include "tensorrt_llm/runtime/common.h" -#include "tensorrt_llm/runtime/iBuffer.h" - -#include -#include -#include -#include -#include -#ifndef _WIN32 -#include -#endif - -// We rely on SizeType32 being int32_t in some places with weak type checking, -// i.e. we're passing void ptr to some function. To prevent mysterious errors -// in the future, we trigger a compilation error here if SizeType32 isn't int32_t. 
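// Illustrative sketch (not part of the patch): the workspace partitioning performed by the
// calcAlignedSize / calcAlignedPointers declarations above boils down to rounding every
// sub-buffer size up to the alignment and carving consecutive pointers out of one
// allocation. This standalone mirror of that arithmetic uses hypothetical names and a plain
// std::vector<char> in place of device memory.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr std::size_t roundUp(std::size_t size, std::size_t alignBytes)
{
    return (size + alignBytes - 1) & ~(alignBytes - 1); // alignBytes must be a power of two
}

int main()
{
    constexpr std::size_t kAlign = 256;
    std::vector<std::size_t> const sizes{100, 300, 17};

    // Total workspace: each size rounded up, plus slack in case the base pointer is unaligned.
    std::size_t total = kAlign - 1;
    for (auto sz : sizes)
    {
        total += roundUp(sz, kAlign);
    }
    std::vector<char> workspace(total);

    // Carve aligned sub-buffers out of the single allocation.
    auto const addr = reinterpret_cast<std::uintptr_t>(workspace.data());
    char* ptr = reinterpret_cast<char*>(roundUp(addr, kAlign));
    std::vector<void*> outPtrs;
    for (auto sz : sizes)
    {
        outPtrs.push_back(ptr);
        ptr += roundUp(sz, kAlign);
    }

    assert(outPtrs.size() == sizes.size());
    assert(ptr <= workspace.data() + workspace.size());
    return 0;
}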
-static_assert(std::is_same::value); - -namespace tensorrt_llm::mpi -{ - -MPI_Datatype getMpiDtype(MpiType dtype) -{ -#if ENABLE_MULTI_DEVICE - static std::unordered_map const dtype_map{ - {MpiType::kBYTE, MPI_BYTE}, - {MpiType::kHALF, MPI_UINT16_T}, - {MpiType::kFLOAT, MPI_FLOAT}, - {MpiType::kDOUBLE, MPI_DOUBLE}, - {MpiType::kBOOL, MPI_C_BOOL}, - {MpiType::kINT8, MPI_INT8_T}, - {MpiType::kUINT8, MPI_UINT8_T}, - {MpiType::kINT32, MPI_INT32_T}, - {MpiType::kUINT32, MPI_UINT32_T}, - {MpiType::kINT64, MPI_INT64_T}, - {MpiType::kUINT64, MPI_UINT64_T}, - {MpiType::kFP8, MPI_UINT8_T}, - {MpiType::kBF16, MPI_UINT16_T}, - {MpiType::kCHAR, MPI_CHAR}, - }; - return dtype_map.at(dtype); -#else - TLLM_THROW("Multi device support is disabled."); -#endif -} - -MPI_Op getMpiOp(MpiOp op) -{ -#if ENABLE_MULTI_DEVICE - static std::unordered_map const op_map{ - {MpiOp::NULLOP, MPI_OP_NULL}, - {MpiOp::MAX, MPI_MAX}, - {MpiOp::MIN, MPI_MIN}, - {MpiOp::SUM, MPI_SUM}, - {MpiOp::PROD, MPI_PROD}, - {MpiOp::LAND, MPI_LAND}, - {MpiOp::BAND, MPI_BAND}, - {MpiOp::LOR, MPI_LOR}, - {MpiOp::BOR, MPI_BOR}, - {MpiOp::LXOR, MPI_LXOR}, - {MpiOp::BXOR, MPI_BXOR}, - {MpiOp::MINLOC, MPI_MINLOC}, - {MpiOp::MAXLOC, MPI_MAXLOC}, - {MpiOp::REPLACE, MPI_REPLACE}, - }; - return op_map.at(op); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE -} - -namespace -{ - -bool mpiInitialized = false; -std::recursive_mutex mpiMutex; - -MpiComm initLocalSession() -{ -#if ENABLE_MULTI_DEVICE - MPI_Comm localComm = nullptr; - MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, &localComm); - MpiComm localSession{localComm, false}; -#else - MpiComm localSession{COMM_SESSION, false}; -#endif // ENABLE_MULTI_DEVICE - return localSession; -} - -} // namespace - -std::vector getWorldRanks(MpiComm const& comm) -{ -#if ENABLE_MULTI_DEVICE - MPI_Group group = nullptr; - MPI_Group worldGroup = nullptr; - - MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); - MPICHECK(MPI_Comm_group(comm, &group)); - - int groupSize = 0; - MPICHECK(MPI_Group_size(group, &groupSize)); - std::vector ranks(groupSize); - std::vector worldRanks(groupSize); - std::iota(ranks.begin(), ranks.end(), 0); - - MPICHECK(MPI_Group_translate_ranks(group, groupSize, ranks.data(), worldGroup, worldRanks.data())); - MPICHECK(MPI_Group_free(&group)); - MPICHECK(MPI_Group_free(&worldGroup)); -#else - std::vector worldRanks{0}; -#endif - return worldRanks; -} - -void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent) -{ - // double-checked locking - if (mpiInitialized) - { - return; - } - std::lock_guard lk(mpiMutex); - if (mpiInitialized) - { - return; - } -#if ENABLE_MULTI_DEVICE - int initialized = 0; - TLLM_MPI_CHECK(MPI_Initialized(&initialized)); - if (!initialized) - { - TLLM_LOG_INFO("Initializing MPI with thread mode %d", threadMode); - int providedMode = 0; - auto requiredMode = static_cast(threadMode); - MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode)); - TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed"); - std::atexit([]() { MPI_Finalize(); }); - - /* - * We only catch SIGABRT and SIGSEGV because most, of not all errors in the worker will cause one of these 2 - * signals. Signals like SIGINT and SIGTERM should be issued to the parent and should terminate MPI workers - * correctly. 
- */ - for (int sig : {SIGABRT, SIGSEGV}) - { - __sighandler_t previousHandler = nullptr; - if (forwardAbortToParent) - { - previousHandler = std::signal(sig, - [](int signal) - { -#ifndef _WIN32 - pid_t parentProcessId = getppid(); - kill(parentProcessId, SIGKILL); -#endif - MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); - }); - } - else - { - previousHandler = std::signal(sig, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); }); - } - TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed"); - } - - // ensure local MPI communicator is initialized - MpiComm::localSession(); - TLLM_LOG_INFO("Initialized MPI"); - } -#endif // ENABLE_MULTI_DEVICE - mpiInitialized = true; -} - -void MpiComm::barrier() const -{ -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Barrier(mComm)); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE -} - -#if ENABLE_MULTI_DEVICE -template >>> -size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dtype, TArgs... args) -{ - constexpr auto maxP1 = static_cast(std::numeric_limits::max()) + 1; - if (TLLM_LIKELY(size < maxP1)) - { - MPICHECK(func(buffer, size, dtype, args...)); - return 1; - } - - constexpr size_t alignment = 256; - int elementSize = 1; - MPICHECK(MPI_Type_size(dtype, &elementSize)); - elementSize = std::min(elementSize, alignment); - - // We cap at max alignment-bytes chunks that can be sent at once. - auto const step = maxP1 - (alignment / elementSize); - - using TCast = std::conditional_t, uint8_t const, uint8_t>; - size_t count = 0; - while (size != 0) - { - auto currentStep = static_cast(std::min(size, step)); - MPICHECK(func(buffer, currentStep, dtype, args...)); - size -= currentStep; - size_t diff = static_cast(currentStep) * elementSize; - buffer = static_cast(buffer) + diff; - ++count; - } - - return count; -} -#endif // ENABLE_MULTI_DEVICE - -std::shared_ptr MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const -{ - std::shared_ptr r = std::make_shared(); -#if ENABLE_MULTI_DEVICE - invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE - return r; -} - -std::shared_ptr MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const -{ - TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU); - return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root); -} - -void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const -{ -#if ENABLE_MULTI_DEVICE - invokeChunked(MPI_Bcast, buffer, size, getMpiDtype(dtype), root, mComm); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE -} - -void MpiComm::bcast(runtime::IBuffer& buf, int root) const -{ - bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root); -} - -std::shared_ptr MpiComm::sendAsync(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const -{ - TLLM_LOG_DEBUG("start MPI_Isend with size %d", size); - std::shared_ptr r = std::make_shared(); -#if ENABLE_MULTI_DEVICE - invokeChunked(MPI_Isend, buffer, size, getMpiDtype(dtype), dest, tag, mComm, &r->mRequest); -#else - TLLM_THROW("Multi device support is disabled."); -#endif - TLLM_LOG_DEBUG("end MPI_Isend with size %d", size); - return r; -} - -std::shared_ptr MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, int tag) const -{ - return sendAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag); -} - -void 
MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const -{ - TLLM_LOG_DEBUG("start MPI_Send with size %d", size); -#if ENABLE_MULTI_DEVICE - invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE - TLLM_LOG_DEBUG("end MPI_Send with size %d", size); -} - -void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const -{ - send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag); -} - -MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const -{ - TLLM_LOG_DEBUG("start MPI_Recv with size %d", size); - MPI_Status status{}; -#if ENABLE_MULTI_DEVICE - invokeChunked(MPI_Recv, buffer, size, getMpiDtype(dtype), source, tag, mComm, &status); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE - TLLM_LOG_DEBUG("end MPI_Recv with size %d", size); - return status; -} - -MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const -{ - return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag); -} - -MpiComm MpiComm::split(int color, int key) const -{ - MPI_Comm splitComm = nullptr; -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Comm_split(mComm, color, key, &splitComm)); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE - return MpiComm{splitComm, true}; -} - -void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const -{ -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm)); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE -} - -void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const -{ -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm)); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE -} - -void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf, - std::vector const& recvcounts, std::vector const& displs, MpiType recvtype) const -{ -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(), - getMpiDtype(recvtype), mComm)); - -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE -} - -void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const -{ -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status)); -#else - TLLM_THROW("Multi device support is disabled."); -#endif // ENABLE_MULTI_DEVICE -} - -bool MpiComm::improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const -{ -#if ENABLE_MULTI_DEVICE - int flag{0}; - MPICHECK(MPI_Improbe(source, tag, mComm, &flag, msg, status)); - return flag != 0; -#else - TLLM_THROW("Multi device support is disabled."); - return false; -#endif -} - -bool MpiComm::iprobe(int source, int tag, MPI_Status* status) const -{ -#if ENABLE_MULTI_DEVICE - int flag{0}; - MPICHECK(MPI_Iprobe(source, tag, mComm, &flag, status)); - return flag != 0; -#else - TLLM_THROW("Multi device support is disabled."); - return false; -#endif -} - -void MpiComm::recvPoll(int source, int tag, int periodMs) const -{ - MPI_Status status; - while (!iprobe(source, tag, &status)) - { - 
std::this_thread::sleep_for(std::chrono::milliseconds(periodMs)); - } -} - -int MpiComm::getRank() const -{ - int rank = 0; -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Comm_rank(mComm, &rank)); -#endif - return rank; -} - -int MpiComm::getSize() const -{ - int world_size = 1; -#if ENABLE_MULTI_DEVICE - MPICHECK(MPI_Comm_size(mComm, &world_size)); -#endif - return world_size; -} - -MpiComm const& MpiComm::world() -{ - TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - static MpiComm commWorld{MPI_COMM_WORLD, false}; - initialize(); - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return commWorld; -} - -MpiComm& MpiComm::mutableSession() -{ - TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - static MpiComm commSession{MPI_COMM_WORLD, false}; - initialize(); - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return commSession; -} - -MpiComm& MpiComm::mutableLocalSession() -{ - TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - static MpiComm localSession = initLocalSession(); - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return localSession; -} - -void MpiComm::refreshLocalSession() -{ -#if ENABLE_MULTI_DEVICE - static std::mutex mutex; - std::unique_lock lock(mutex); - auto initSessionRanks = getWorldRanks(MpiComm::session()); - auto localSessionRanks = getWorldRanks(MpiComm::localSession()); - - // Add to intersectionRanks in order of initSessionRanks - std::vector intersectionRanks; - std::unordered_set localSessionRanksSet(localSessionRanks.begin(), localSessionRanks.end()); - for (auto rank : initSessionRanks) - { - if (localSessionRanksSet.find(rank) != localSessionRanksSet.end()) - { - intersectionRanks.push_back(rank); - } - } - - MPI_Group worldGroup = nullptr; - MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); - MPI_Group localGroup = nullptr; - MPICHECK(MPI_Group_incl(worldGroup, intersectionRanks.size(), intersectionRanks.data(), &localGroup)); - MPI_Comm localComm = nullptr; - MPICHECK(MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm)); - MpiComm::mutableLocalSession().mFreeComm = true; - MpiComm::mutableLocalSession() = MpiComm{localComm, false}; - TLLM_LOG_INFO("Refreshed the MPI local session"); -#endif // ENABLE_MULTI_DEVICE -} - -MpiComm::MpiComm(MPI_Comm g, bool freeComm) - : mComm{g} - , mFreeComm{freeComm} -{ - TLLM_CHECK(mComm != MPI_COMM_NULL); -} - -MpiComm::~MpiComm() noexcept -{ -#if ENABLE_MULTI_DEVICE - if (mFreeComm && mComm) - { - if (MPI_Comm_free(&mComm) != MPI_SUCCESS) - { - TLLM_LOG_ERROR("MPI_Comm_free failed"); - } - } -#endif // ENABLE_MULTI_DEVICE -} - -MpiComm::MpiComm(MpiComm&& comm) noexcept - : mComm{comm.mComm} - , mFreeComm{comm.mFreeComm} -{ - comm.mFreeComm = false; -} - -MpiComm& MpiComm::operator=(MpiComm&& comm) noexcept -{ - this->~MpiComm(); - mComm = comm.mComm; - mFreeComm = comm.mFreeComm; - comm.mFreeComm = false; - return *this; -} - -MpiWaitThread::MpiWaitThread(std::string name, std::function funcWait, std::function funcSetup) - : mName{name.c_str()} - , mFuncWait{funcWait} - , mFuncSetup{funcSetup} -{ - TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); - mThread = std::make_unique(&MpiWaitThread::sideThread, this); - TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); -} - -MpiWaitThread::~MpiWaitThread() -{ - TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); - waitStop(); - mShouldExit.store(true); - notifyStart(); - mThread->join(); - mThread.reset(nullptr); - TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); -} 
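// Illustrative sketch (not part of the patch): a standalone mirror of the start/stop
// handshake used by the MpiWaitThread implementation above -- a side thread parks on a
// condition variable until it is started, runs one unit of work, then reports idle; the
// destructor stops it by waiting for idle, setting the exit flag and waking it one last
// time. Class and member names here are hypothetical simplifications.
#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <thread>

class WaitThreadSketch
{
public:
    explicit WaitThreadSketch(std::function<void()> work)
        : mWork{std::move(work)}
        , mThread{[this] { sideThread(); }}
    {
    }

    ~WaitThreadSketch()
    {
        waitStop();              // wait until the worker reports idle
        mShouldExit.store(true); // then ask it to leave the loop
        notifyStart();           // wake it one last time so it observes the flag
        mThread.join();
    }

    void notifyStart() { setRunning(true); }
    void waitStop() { waitRunning(false); }

private:
    void sideThread()
    {
        while (!mShouldExit.load())
        {
            setRunning(false); // report idle (notifyStop in the removed code)
            waitRunning(true); // park until the next notifyStart
            if (!mShouldExit.load())
            {
                mWork();
            }
        }
    }

    void setRunning(bool value)
    {
        std::lock_guard<std::mutex> lock(mMutex);
        mRunning = value;
        mCondVar.notify_one();
    }

    void waitRunning(bool value)
    {
        std::unique_lock<std::mutex> lock(mMutex);
        mCondVar.wait(lock, [this, value] { return mRunning == value; });
    }

    std::function<void()> mWork;
    std::atomic<bool> mShouldExit{false};
    bool mRunning{true}; // flips to false once the worker has parked
    std::mutex mMutex;
    std::condition_variable mCondVar;
    std::thread mThread;
};

int main()
{
    int counter = 0;
    {
        WaitThreadSketch waiter{[&counter] { ++counter; }};
        waiter.waitStop();    // ensure the worker is idle before starting it
        waiter.notifyStart(); // run one unit of work
        waiter.waitStop();    // wait until it is parked again
    }                         // destructor performs the shutdown handshake
    return counter == 1 ? 0 : 1;
}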
- -void MpiWaitThread::sideThread() -{ - if (mFuncSetup) - { - mFuncSetup(); - } - while (!mShouldExit.load()) - { - notifyStop(); - waitStart(); - mFuncWait(); - } -} - -void MpiWaitThread::waitStart() -{ - TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); - std::unique_lock lock(mMutex); - mCondVar.wait(lock, [this] { return mRunning; }); - TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); -} - -void MpiWaitThread::waitStop() -{ - TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); - std::unique_lock lock(mMutex); - mCondVar.wait(lock, [this] { return !mRunning; }); - TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); -} - -void MpiWaitThread::notifyStart() -{ - TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); - std::lock_guard lock(mMutex); - mRunning = true; - mCondVar.notify_one(); - TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); -} - -void MpiWaitThread::notifyStop() -{ - TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); - std::lock_guard lock(mMutex); - mRunning = false; - mCondVar.notify_one(); - TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); -} - -} // namespace tensorrt_llm::mpi diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h deleted file mode 100644 index 0a9d51975af..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h +++ /dev/null @@ -1,46 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#include - -namespace tensorrt_llm::common::nvtx -{ -inline nvtx3::color nextColor() -{ -#ifndef NVTX_DISABLE - constexpr std::array kColors{nvtx3::color{0xff00ff00}, nvtx3::color{0xff0000ff}, nvtx3::color{0xffffff00}, - nvtx3::color{0xffff00ff}, nvtx3::color{0xff00ffff}, nvtx3::color{0xffff0000}, nvtx3::color{0xffffffff}}; - constexpr auto numColors = kColors.size(); - - static thread_local std::size_t colorId = 0; - auto const color = kColors[colorId]; - colorId = colorId + 1 >= numColors ? 0 : colorId + 1; - return color; -#else - return nvtx3::color{0}; -#endif -} - -} // namespace tensorrt_llm::common::nvtx - -#define NVTX3_SCOPED_RANGE_WITH_NAME(range, name) \ - ::nvtx3::scoped_range range(::tensorrt_llm::common::nvtx::nextColor(), name) -#define NVTX3_SCOPED_RANGE(range) NVTX3_SCOPED_RANGE_WITH_NAME(range##_range, #range) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp deleted file mode 100644 index 39aefda481a..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp +++ /dev/null @@ -1,323 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tensorrt_llm/common/opUtils.h" -#include "tensorrt_llm/common/mpiUtils.h" - -#include "cuda.h" -#include -#include -#include -#include -#include -#include -#include - -#ifdef _MSC_VER -#define FN_NAME __FUNCTION__ -#else -#define FN_NAME __func__ -#endif - -#if ENABLE_MULTI_DEVICE - -std::unordered_map* getDtypeMap() -{ - static std::unordered_map dtypeMap = {{nvinfer1::DataType::kFLOAT, ncclFloat32}, - {nvinfer1::DataType::kHALF, ncclFloat16}, {nvinfer1::DataType::kBF16, ncclBfloat16}}; - return &dtypeMap; -} - -namespace -{ - -// Get NCCL unique ID for a group of ranks. -ncclUniqueId getUniqueId(std::set const& group) noexcept -{ - auto const rank = COMM_SESSION.getRank(); - TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank); - ncclUniqueId id; - if (rank == *group.begin()) - { - NCCLCHECK(ncclGetUniqueId(&id)); - for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it) - { - COMM_SESSION.sendValue(id, *it, 0); - } - } - else - { - COMM_SESSION.recvValue(id, *group.begin(), 0); - } - TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank); - return id; -} -} // namespace - -std::shared_ptr getComm(std::set const& group) -{ - auto const rank = COMM_SESSION.getRank(); - TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank); - static std::map, std::shared_ptr> commMap; - static std::mutex mutex; - std::lock_guard lock(mutex); - std::ostringstream oss; - int index = 0; - for (auto const& rank : group) - { - if (index != 0) - { - oss << ","; - } - oss << rank; - index++; - } - auto groupStr = oss.str(); - auto it = commMap.find(group); - if (it != commMap.end()) - { - auto ncclComm = it->second; - TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank); - return ncclComm; - } - - TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank); - ncclUniqueId id = getUniqueId(group); - int groupRank = 0; - for (auto const& currentRank : group) - { - if (rank == currentRank) - break; - ++groupRank; - } - TLLM_CHECK(groupRank < group.size()); - std::shared_ptr ncclComm(new ncclComm_t, - [](ncclComm_t* comm) - { - ncclCommDestroy(*comm); - delete comm; - }); - NCCLCHECK(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank)); - commMap[group] = ncclComm; - TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank); - return ncclComm; -} -#endif // ENABLE_MULTI_DEVICE - -void const* tensorrt_llm::common::getCommSessionHandle() -{ -#if ENABLE_MULTI_DEVICE - return &COMM_SESSION; -#else - return nullptr; -#endif // ENABLE_MULTI_DEVICE -} - -namespace -{ - -// Get current cuda context, a default context will be created if there is no context. 
-inline CUcontext getCurrentCudaCtx() -{ - CUcontext ctx{}; - CUresult err = cuCtxGetCurrent(&ctx); - if (err == CUDA_ERROR_NOT_INITIALIZED || ctx == nullptr) - { - TLLM_CUDA_CHECK(cudaFree(nullptr)); - err = cuCtxGetCurrent(&ctx); - } - TLLM_CHECK(err == CUDA_SUCCESS); - return ctx; -} - -// Helper to create per-cuda-context singleton managed by std::shared_ptr. -// Unlike conventional singletons, singleton created with this will be released -// when not needed, instead of on process exit. -// Objects of this class shall always be declared static / global, and shall never own CUDA -// resources. -template -class PerCudaCtxSingletonCreator -{ -public: - using CreatorFunc = std::function()>; - using DeleterFunc = std::function; - - // creator returning std::unique_ptr is by design. - // It forces separation of memory for T and memory for control blocks. - // So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released. - // creator itself must not own CUDA resources. Only the object it creates can. - PerCudaCtxSingletonCreator(CreatorFunc creator, DeleterFunc deleter) - : mCreator{std::move(creator)} - , mDeleter{std::move(deleter)} - { - } - - std::shared_ptr operator()() - { - std::lock_guard lk{mMutex}; - CUcontext ctx{getCurrentCudaCtx()}; - std::shared_ptr result = mObservers[ctx].lock(); - if (result == nullptr) - { - // Create the resource and register with an observer. - result = std::shared_ptr{mCreator().release(), - [this, ctx](T* obj) - { - if (obj == nullptr) - { - return; - } - mDeleter(obj); - - // Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts - // frequently. - std::shared_ptr observedObjHolder; // Delay destroy to avoid dead lock. - std::lock_guard lk{mMutex}; - // Must check observer again because another thread may created new instance for this ctx just - // before we lock mMutex. We can't infer that the observer is stale from the fact that obj is - // destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic - // operation, and the observer may be changed to observe another instance. - observedObjHolder = mObservers.at(ctx).lock(); - if (observedObjHolder == nullptr) - { - mObservers.erase(ctx); - } - }}; - mObservers.at(ctx) = result; - } - return result; - } - -private: - CreatorFunc mCreator; - DeleterFunc mDeleter; - mutable std::mutex mMutex; - // CUDA resources are per-context. - std::unordered_map> mObservers; -}; - -template -class PerThreadSingletonCreator -{ -public: - using CreatorFunc = std::function()>; - using DeleterFunc = std::function; - - // creator returning std::unique_ptr is by design. - // It forces separation of memory for T and memory for control blocks. - // So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released. - // creator itself must not own CUDA resources. Only the object it creates can. - PerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter) - : mCreator{std::move(creator)} - , mDeleter{std::move(deleter)} - { - } - - std::shared_ptr operator()() - { - std::lock_guard lk{mMutex}; - - std::thread::id thread = std::this_thread::get_id(); - std::shared_ptr result = mObservers[thread].lock(); - - if (result == nullptr) - { - // Create the resource and register with an observer. 
- result = std::shared_ptr{mCreator().release(), - [this, thread](T* obj) - { - if (obj == nullptr) - { - return; - } - mDeleter(obj); - - // Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts - // frequently. - std::shared_ptr observedObjHolder; // Delay destroy to avoid dead lock. - std::lock_guard lk{mMutex}; - // Must check observer again because another thread may created new instance for this ctx just - // before we lock mMutex. We can't infer that the observer is stale from the fact that obj is - // destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic - // operation, and the observer may be changed to observe another instance. - observedObjHolder = mObservers.at(thread).lock(); - if (observedObjHolder == nullptr) - { - mObservers.erase(thread); - } - }}; - mObservers.at(thread) = result; - } - return result; - } - -private: - CreatorFunc mCreator; - DeleterFunc mDeleter; - mutable std::mutex mMutex; - // CUDA resources are per-thread. - std::unordered_map> mObservers; -}; - -} // namespace - -std::shared_ptr getCublasHandle() -{ - static PerThreadSingletonCreator creator( - []() -> auto - { - auto handle = std::unique_ptr(new cublasHandle_t); - TLLM_CUDA_CHECK(cublasCreate(handle.get())); - return handle; - }, - [](cublasHandle_t* handle) - { - TLLM_CUDA_CHECK(cublasDestroy(*handle)); - delete handle; - }); - return creator(); -} - -std::shared_ptr getCublasLtHandle() -{ - static PerThreadSingletonCreator creator( - []() -> auto - { - auto handle = std::unique_ptr(new cublasLtHandle_t); - TLLM_CUDA_CHECK(cublasLtCreate(handle.get())); - return handle; - }, - [](cublasLtHandle_t* handle) - { - TLLM_CUDA_CHECK(cublasLtDestroy(*handle)); - delete handle; - }); - return creator(); -} - -std::shared_ptr getCublasMMWrapper(std::shared_ptr cublasHandle, - std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace) -{ - static PerThreadSingletonCreator creator( - [cublasHandle, cublasltHandle, stream, workspace]() -> auto - { - auto wrapper = std::unique_ptr( - new tensorrt_llm::common::CublasMMWrapper(cublasHandle, cublasltHandle, stream, workspace)); - return wrapper; - }, - [](tensorrt_llm::common::CublasMMWrapper* wrapper) { delete wrapper; }); - return creator(); -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h deleted file mode 100644 index 4e278e5cf23..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h +++ /dev/null @@ -1,215 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
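// Illustrative sketch (not part of the patch): the PerThreadSingletonCreator /
// getCublasHandle pattern above hands every thread its own shared resource and lets that
// resource die when the last shared_ptr drops, by storing only weak_ptr observers keyed by
// std::thread::id. This standalone mirror uses a dummy resource instead of a cuBLAS handle;
// all names are hypothetical.
#include <memory>
#include <mutex>
#include <thread>
#include <unordered_map>

struct DummyHandle
{
    std::thread::id owner;
};

std::shared_ptr<DummyHandle> getPerThreadHandle()
{
    static std::unordered_map<std::thread::id, std::weak_ptr<DummyHandle>> observers;
    static std::mutex mutex;

    std::lock_guard<std::mutex> lock(mutex);
    auto const tid = std::this_thread::get_id();

    // Reuse the live instance for this thread if one is still alive somewhere.
    if (auto existing = observers[tid].lock())
    {
        return existing;
    }

    // Otherwise create a fresh instance; the deleter erases the stale observer entry.
    // Note: the removed implementation additionally re-checks the observer before erasing,
    // to tolerate an instance being re-created for the same thread while the deleter runs;
    // that refinement is omitted here for brevity.
    std::shared_ptr<DummyHandle> handle(new DummyHandle{tid},
        [](DummyHandle* obj)
        {
            std::lock_guard<std::mutex> lock(mutex);
            observers.erase(obj->owner);
            delete obj;
        });
    observers[tid] = handle;
    return handle;
}

int main()
{
    auto a = getPerThreadHandle();
    auto b = getPerThreadHandle();
    return a == b ? 0 : 1; // same thread -> same instance while either pointer is alive
}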
- */ - -#pragma once - -#include "tensorrt_llm/common/cublasMMWrapper.h" -#include "tensorrt_llm/common/workspace.h" - -#include -#include -#include -#include -#if ENABLE_MULTI_DEVICE -#include -#endif // ENABLE_MULTI_DEVICE - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tensorrt_llm::common -{ - -// Write values into buffer -template -void write(char*& buffer, T const& val) -{ - std::memcpy(buffer, &val, sizeof(T)); - buffer += sizeof(T); -} - -// Read values from buffer -template -void read(char const*& buffer, T& val) -{ - std::memcpy(&val, buffer, sizeof(T)); - buffer += sizeof(T); -} - -// Like std::unique_ptr, but does not prevent generation of default copy constructor when used as class members. -// The copy constructor produces nullptr. So the plugin default copy constructor will not really copy this, and -// your clone() implementation is responsible for initializing such data members. -// With this we can simplify clone() implementation when there are many data members including at least one unique_ptr. -template > -class UniqPtrWNullCopy : public std::unique_ptr -{ -public: - using std::unique_ptr::unique_ptr; - - // for compatibility with std::make_unique - explicit UniqPtrWNullCopy(std::unique_ptr&& src) - : std::unique_ptr::unique_ptr{std::move(src)} - { - } - - // copy constructor produces nullptr - UniqPtrWNullCopy(UniqPtrWNullCopy const&) - : std::unique_ptr::unique_ptr{} - { - } -}; - -// for testing only -void const* getCommSessionHandle(); -} // namespace tensorrt_llm::common - -inline bool isBuilding() -{ - auto constexpr key = "IS_BUILDING"; - auto const val = getenv(key); - return val != nullptr && std::string(val) == "1"; -} - -#if ENABLE_MULTI_DEVICE -#define NCCLCHECK(cmd) \ - do \ - { \ - ncclResult_t r = cmd; \ - if (r != ncclSuccess) \ - { \ - printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) - -std::unordered_map* getDtypeMap(); - -std::shared_ptr getComm(std::set const& group); - -#endif // ENABLE_MULTI_DEVICE - -//! To save GPU memory, all the plugins share the same cublas and cublasLt handle globally. -//! Get cublas and cublasLt handle for current cuda context -std::shared_ptr getCublasHandle(); -std::shared_ptr getCublasLtHandle(); -std::shared_ptr getCublasMMWrapper(std::shared_ptr cublasHandle, - std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace); - -#ifndef DEBUG - -#define PLUGIN_CHECK(status) \ - do \ - { \ - if (status != 0) \ - abort(); \ - } while (0) - -#define ASSERT_PARAM(exp) \ - do \ - { \ - if (!(exp)) \ - return STATUS_BAD_PARAM; \ - } while (0) - -#define ASSERT_FAILURE(exp) \ - do \ - { \ - if (!(exp)) \ - return STATUS_FAILURE; \ - } while (0) - -#define CSC(call, err) \ - do \ - { \ - cudaError_t cudaStatus = call; \ - if (cudaStatus != cudaSuccess) \ - { \ - return err; \ - } \ - } while (0) - -#define DEBUG_PRINTF(...) 
\ - do \ - { \ - } while (0) - -#else - -#define ASSERT_PARAM(exp) \ - do \ - { \ - if (!(exp)) \ - { \ - fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \ - return STATUS_BAD_PARAM; \ - } \ - } while (0) - -#define ASSERT_FAILURE(exp) \ - do \ - { \ - if (!(exp)) \ - { \ - fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \ - return STATUS_FAILURE; \ - } \ - } while (0) - -#define CSC(call, err) \ - do \ - { \ - cudaError_t cudaStatus = call; \ - if (cudaStatus != cudaSuccess) \ - { \ - printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \ - return err; \ - } \ - } while (0) - -#define PLUGIN_CHECK(status) \ - { \ - if (status != 0) \ - { \ - DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \ - abort(); \ - } \ - } - -#define DEBUG_PRINTF(...) \ - do \ - { \ - printf(__VA_ARGS__); \ - } while (0) - -#endif // DEBUG - -#define NVML_CHECK(cmd) \ - do \ - { \ - nvmlReturn_t r = cmd; \ - if (r != NVML_SUCCESS) \ - { \ - printf("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(r)); \ - exit(EXIT_FAILURE); \ - } \ - } while (0) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h b/sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h new file mode 100644 index 00000000000..052d9c8c819 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h @@ -0,0 +1,358 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ + +class QuantMode +{ + // [WARNING] KEEP BELOW DEFINITION IN SYNC WITH tensorrt_llm/quantization/mode.py +public: + using BaseType = std::uint32_t; + + explicit constexpr QuantMode(BaseType value) noexcept + : mValue{value} + { + } + + QuantMode() noexcept = default; + + constexpr QuantMode(QuantMode const&) noexcept = default; + + constexpr QuantMode& operator=(QuantMode const& other) noexcept = default; + + static constexpr QuantMode none() noexcept + { + return QuantMode(BaseType(0)); + } + + static constexpr QuantMode int4Weights() noexcept + { + return QuantMode(BaseType(1u) << 0); + } + + static constexpr QuantMode int8Weights() noexcept + { + return QuantMode(BaseType(1u) << 1); + } + + static constexpr QuantMode activations() noexcept + { + return QuantMode(BaseType(1u) << 2); + } + + static constexpr QuantMode perChannelScaling() noexcept + { + return QuantMode(BaseType(1u) << 3); + } + + static constexpr QuantMode perTokenScaling() noexcept + { + return QuantMode(BaseType(1u) << 4); + } + + static constexpr QuantMode perGroupScaling() noexcept + { + return QuantMode(BaseType(1u) << 5); + } + + static constexpr QuantMode int8KvCache() noexcept + { + return QuantMode(BaseType(1u) << 6); + } + + static constexpr QuantMode fp8KvCache() noexcept + { + return QuantMode(BaseType(1u) << 7); + } + + static constexpr QuantMode fp8Qdq() noexcept + { + return QuantMode(BaseType(1u) << 8); + } + + static constexpr QuantMode fp8RowWise() noexcept + { + return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9); + } + + static constexpr QuantMode w4a8QServe() noexcept + { + return QuantMode(BaseType(1u) << 10); + } + + constexpr BaseType value() const noexcept + { + return mValue; + } + + constexpr bool isSet(QuantMode const& mode) const noexcept + { + return (mValue & mode.value()) == mode.value(); + } + + constexpr bool hasInt4Weights() const noexcept + { + return isSet(int4Weights()); + } + + constexpr bool hasInt8Weights() const noexcept + { + return isSet(int8Weights()); + } + + constexpr bool hasActivations() const noexcept + { + return isSet(activations()); + } + + constexpr bool hasPerChannelScaling() const noexcept + { + return isSet(perChannelScaling()); + } + + constexpr bool hasPerTokenScaling() const noexcept + { + return isSet(perTokenScaling()); + } + + constexpr bool hasPerGroupScaling() const noexcept + { + return isSet(perGroupScaling()); + } + + constexpr bool hasStaticActivationScaling() const noexcept + { + return !hasPerTokenScaling(); + } + + constexpr bool hasInt8KvCache() const noexcept + { + return isSet(int8KvCache()); + } + + constexpr bool hasFp8KvCache() const noexcept + { + return isSet(fp8KvCache()); + } + + constexpr bool hasFp8Qdq() const noexcept + { + return isSet(fp8Qdq()); + } + + constexpr bool hasFp8RowWise() const noexcept + { + return isSet(fp8RowWise()); + } + + constexpr bool hasKvCacheQuant() const noexcept + { + return hasInt8KvCache() || hasFp8KvCache(); + } + + static constexpr QuantMode fromDescription(bool quantizeWeights = false, bool quantizeActivations = false, + bool perToken = false, bool perChannel = false, bool perGroup = false, bool useInt4Weights = false, + bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false, + bool useW4a8QServe = false) + { + QuantMode quantMode{}; + if (quantizeWeights) + { + if (useInt4Weights) + quantMode += int4Weights(); + else + quantMode 
+        }
+
+        if (quantizeActivations)
+        {
+            quantMode += activations();
+        }
+
+        if (perChannel)
+        {
+            quantMode += QuantMode::perChannelScaling();
+        }
+        if (perToken)
+        {
+            quantMode += QuantMode::perTokenScaling();
+        }
+        if (perGroup)
+        {
+            quantMode += QuantMode::perGroupScaling();
+        }
+
+        if (useInt8KvCache)
+        {
+            quantMode += int8KvCache();
+        }
+
+        if (useFp8KvCache)
+        {
+            quantMode += fp8KvCache();
+        }
+
+        if (useFp8Qdq)
+        {
+            quantMode += fp8Qdq();
+        }
+
+        if (useFp8RowWise)
+        {
+            quantMode += fp8RowWise();
+        }
+
+        if (useW4a8QServe)
+        {
+            quantMode += w4a8QServe();
+        }
+
+        return quantMode;
+    }
+
+    static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false)
+    {
+        return fromDescription(true, true, perToken, perChannel);
+    }
+
+    static constexpr QuantMode useQServe(bool perGroup)
+    {
+        return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true);
+    }
+
+    static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false)
+    {
+        return fromDescription(true, false, false, false, perGroup, useInt4Weights);
+    }
+
+    static QuantMode const fromQuantAlgo(
+        std::optional<std::string> quantAlgo = std::nullopt, std::optional<std::string> kvCacheQuantAlgo = std::nullopt)
+    {
+        QuantMode quantMode{};
+        if (quantAlgo == "W8A16")
+        {
+            quantMode = useWeightOnly(false, false);
+        }
+        else if (quantAlgo == "W4A16")
+        {
+            quantMode = useWeightOnly(true, false);
+        }
+        else if (quantAlgo == "W4A16_AWQ")
+        {
+            quantMode = useWeightOnly(true, true);
+        }
+        else if (quantAlgo == "W4A8_AWQ")
+        {
+            quantMode = useWeightOnly(true, true);
+        }
+        else if (quantAlgo == "W4A8_QSERVE_PER_GROUP")
+        {
+            quantMode = useQServe(false);
+        }
+        else if (quantAlgo == "W4A8_QSERVE_PER_CHANNEL")
+        {
+            quantMode = useQServe(true);
+        }
+        else if (quantAlgo == "W4A16_GPTQ")
+        {
+            quantMode = useWeightOnly(true, true);
+        }
+        else if (quantAlgo == "W8A8_SQ_PER_CHANNEL")
+        {
+            quantMode = useSmoothQuant(false, true);
+        }
+        else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PLUGIN")
+        {
+            quantMode = useSmoothQuant(false, false);
+        }
+        else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN")
+        {
+            quantMode = useSmoothQuant(true, true);
+        }
+        else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN")
+        {
+            quantMode = useSmoothQuant(false, true);
+        }
+        else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN")
+        {
+            quantMode = useSmoothQuant(true, false);
+        }
+        else if (quantAlgo == "FP8")
+        {
+            quantMode = fromDescription(false, false, false, false, false, false, false, false, true);
+        }
+        else if (quantAlgo == "FP8_ROWWISE")
+        {
+            quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true);
+        }
+
+        if (kvCacheQuantAlgo == "INT8")
+        {
+            quantMode += int8KvCache();
+        }
+        else if (kvCacheQuantAlgo == "FP8")
+        {
+            quantMode += fp8KvCache();
+        }
+
+        return quantMode;
+    }
+
+    constexpr QuantMode operator+(QuantMode const& other) const noexcept
+    {
+        return QuantMode(mValue | other.mValue);
+    }
+
+    constexpr QuantMode& operator+=(QuantMode const& other) noexcept
+    {
+        return *this = *this + other;
+    }
+
+    constexpr QuantMode operator-(QuantMode const& other) const noexcept
+    {
+        return QuantMode(mValue & ~other.mValue);
+    }
+
+    constexpr QuantMode& operator-=(QuantMode const& other) noexcept
+    {
+        return *this = *this - other;
+    }
+
+    constexpr bool operator==(QuantMode const& other) const noexcept
+    {
+        return mValue == other.mValue;
+    }
+
+    constexpr bool operator!=(QuantMode const& other) const noexcept
+    {
+        return !(*this == other);
+    }
+
+private:
+    BaseType mValue{0};
+};
+
+} // namespace common
+} // namespace tensorrt_llm
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h
deleted file mode 100644
index 9cda9fa0d42..00000000000
--- a/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h
+++ /dev/null
@@ -1,123 +0,0 @@
-/*
- * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#pragma once
-
-#include <functional>
-#include <numeric>
-#include <optional>
-#include <sstream>
-
-namespace tensorrt_llm::common::stl_utils
-{
-
-template <typename TInputIt, typename TOutputIt, typename TBinOp>
-constexpr TOutputIt basicInclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, TBinOp op)
-{
-    if (first != last)
-    {
-        auto val = *first;
-        while (true)
-        {
-            *dFirst = val;
-            ++dFirst;
-            ++first;
-            if (first == last)
-            {
-                break;
-            }
-            val = op(std::move(val), *first);
-        }
-    }
-    return dFirst;
-}
-
-template <typename TInputIt, typename TOutputIt>
-constexpr TOutputIt inclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst)
-{
-#if defined(__GNUC__) && __GNUC__ <= 8
-    return basicInclusiveScan(first, last, dFirst, std::plus<>{});
-#else
-    return std::inclusive_scan(first, last, dFirst);
-#endif
-}
-
-template <typename TInputIt, typename TOutputIt, typename T, typename TBinOp>
-constexpr TOutputIt basicExclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init, TBinOp op)
-{
-    if (first != last)
-    {
-        while (true)
-        {
-            T tmp{op(init, *first)};
-            *dFirst = init;
-            ++dFirst;
-            ++first;
-            if (first == last)
-            {
-                break;
-            }
-            init = std::move(tmp);
-        }
-    }
-    return dFirst;
-}
-
-template <typename TInputIt, typename TOutputIt, typename T>
-constexpr TOutputIt exclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init)
-{
-#if defined(__GNUC__) && __GNUC__ <= 8
-    return basicExclusiveScan(first, last, dFirst, std::move(init), std::plus<>{});
-#else
-    return std::exclusive_scan(first, last, dFirst, std::move(init));
-#endif
-}
-
-template <typename T, typename = void>
-struct HasOperatorOutput : std::false_type
-{
-};
-
-template <typename T>
-struct HasOperatorOutput<T, std::void_t<decltype((std::declval<std::ostream&>() << std::declval<T>()))>>
-    : std::true_type
-{
-};
-
-template <typename T>
-std::string toString(T const& t, typename std::enable_if_t<HasOperatorOutput<T>::value, int> = 0)
-{
-    std::ostringstream oss;
-    oss << t;
-    return oss.str();
-}
-
-template <typename T>
-std::string toString(std::optional<T> const& t, typename std::enable_if_t<HasOperatorOutput<T>::value, int> = 0)
-{
-    std::ostringstream oss;
-    if (t)
-    {
-        oss << t.value();
-    }
-    else
-    {
-        oss << "None";
-    }
-    return oss.str();
-}
-
-} // namespace tensorrt_llm::common::stl_utils
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h
new file mode 100644
index 00000000000..9c5ecde98c5
--- /dev/null
+++ b/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h
@@ -0,0 +1,113 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#if ENABLE_BF16
+#include <cuda_bf16.h>
+#endif // ENABLE_BF16
+#include <cuda_fp16.h>
+
+#include <memory>  // std::make_unique
+#include <sstream> // std::stringstream
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+namespace tensorrt_llm::common
+{
+#if ENABLE_BF16
+static inline std::basic_ostream<char>& operator<<(std::basic_ostream<char>& stream, __nv_bfloat16 const& val)
+{
+    stream << __bfloat162float(val);
+    return stream;
+}
+#endif // ENABLE_BF16
+
+static inline std::basic_ostream<char>& operator<<(std::basic_ostream<char>& stream, __half const& val)
+{
+    stream << __half2float(val);
+    return stream;
+}
+
+inline std::string fmtstr(std::string const& s)
+{
+    return s;
+}
+
+inline std::string fmtstr(std::string&& s)
+{
+    return s;
+}
+
+#if defined(_MSC_VER)
+std::string fmtstr(char const* format, ...);
+#else
+std::string fmtstr(char const* format, ...) __attribute__((format(printf, 1, 2)));
+#endif
+
+// __PRETTY_FUNCTION__ is used for neat debugging printing but is not supported on Windows
+// The alternative is __FUNCSIG__, which is similar but not identical
+#if defined(_WIN32)
+#define __PRETTY_FUNCTION__ __FUNCSIG__
+#endif
+
+auto constexpr kDefaultDelimiter = ", ";
+
+template <typename U, typename TStream, typename T>
+inline TStream& arr2outCasted(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter)
+{
+    out << "(";
+    if (size > 0)
+    {
+        for (size_t i = 0; i < size - 1; ++i)
+        {
+            out << static_cast<U>(arr[i]) << delim;
+        }
+        out << static_cast<U>(arr[size - 1]);
+    }
+    out << ")";
+    return out;
+}
+
+template <typename TStream, typename T>
+inline TStream& arr2out(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter)
+{
+    return arr2outCasted<T>(out, arr, size, delim);
+}
+
+template <typename T>
+inline std::string arr2str(T* arr, size_t size, char const* delim = kDefaultDelimiter)
+{
+    std::stringstream ss;
+    return arr2out(ss, arr, size, delim).str();
+}
+
+template <typename T>
+inline std::string vec2str(std::vector<T> const& vec, char const* delim = kDefaultDelimiter)
+{
+    return arr2str(vec.data(), vec.size(), delim);
+}
+
+inline bool strStartsWith(std::string const& str, std::string const& prefix)
+{
+    return str.rfind(prefix, 0) == 0;
+}
+
+/// @brief Split a string into a set of strings using a delimiter
+std::unordered_set<std::string> str2set(std::string const& input, char delimiter);
+
+} // namespace tensorrt_llm::common
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp
deleted file mode 100644
index c00041abdac..00000000000
--- a/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <chrono>
-#include <iomanip>
-#include <sstream>
-
-#include "tensorrt_llm/common/timestampUtils.h"
-
-namespace tensorrt_llm::common
-{
-
-std::string getCurrentTimestamp()
-{
-    auto now = std::chrono::system_clock::now();
-    auto now_t = std::chrono::system_clock::to_time_t(now);
-    auto tm = *std::localtime(&now_t);
-
-    auto epoch_to_now = now.time_since_epoch();
-    auto seconds = std::chrono::duration_cast<std::chrono::seconds>(epoch_to_now);
-    auto us = std::chrono::duration_cast<std::chrono::microseconds>(epoch_to_now - seconds);
-
-    std::ostringstream stream;
-    stream << std::put_time(&tm, "%m-%d-%Y %H:%M:%S");
-    stream << "." << std::setfill('0') << std::setw(6) << us.count();
-    return stream.str();
-}
-
-} // namespace tensorrt_llm::common
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h
similarity index 50%
rename from sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h
rename to sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h
index f52f23028c1..47e0e63d3fc 100644
--- a/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h
+++ b/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h
@@ -14,12 +14,35 @@
  * limitations under the License.
  */
 
+#pragma once
+
+#include <array>
+#include <cstddef>
+#include <stdexcept>
 #include <string>
 
+#define NEW_TLLM_EXCEPTION(...)                                                                                        \
+    tensorrt_llm::common::TllmException(__FILE__, __LINE__, tensorrt_llm::common::fmtstr(__VA_ARGS__))
+
 namespace tensorrt_llm::common
 {
 
-/// @brief Get the current timestamp in the format "MM-DD-YYYY HH:MM:SS:uuuuuu"
-std::string getCurrentTimestamp();
+class TllmException : public std::runtime_error
+{
+public:
+    static auto constexpr MAX_FRAMES = 128;
+
+    explicit TllmException(char const* file, std::size_t line, std::string const& msg);
+
+    ~TllmException() noexcept override;
+
+    [[nodiscard]] std::string getTrace() const;
+
+    static std::string demangle(char const* name);
+
+private:
+    std::array<void*, MAX_FRAMES> mCallstack{};
+    int mNbFrames;
+};
 
 } // namespace tensorrt_llm::common
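For reviewers, here is a minimal usage sketch (not part of the patch) of the QuantMode bitmask vendored in quantization.h above. The include path and the main() harness are assumptions for illustration; the QuantMode calls themselves mirror the header as added in this diff.

// quantmode_usage_sketch.cpp -- illustrative only, not added by this patch.
#include <iostream>

#include "tensorrt_llm/common/quantization.h" // assumed include path for the header above

int main()
{
    using tensorrt_llm::common::QuantMode;

    // "W4A16_AWQ" maps to weight-only int4 quantization with per-group scaling;
    // the second argument selects an INT8 KV cache.
    auto const mode = QuantMode::fromQuantAlgo("W4A16_AWQ", "INT8");

    std::cout << std::boolalpha;
    std::cout << "int4 weights:      " << mode.hasInt4Weights() << '\n';
    std::cout << "per-group scaling: " << mode.hasPerGroupScaling() << '\n';
    std::cout << "int8 KV cache:     " << mode.hasInt8KvCache() << '\n';

    // The operators are plain bit manipulation: operator+ sets flags,
    // operator- clears them, and isSet() checks containment.
    auto const stripped = mode - QuantMode::int8KvCache();
    std::cout << "KV-cache quant after removal: " << stripped.hasKvCacheQuant() << '\n';

    // Everything except fromQuantAlgo() is constexpr, so modes can be checked at compile time.
    static_assert(QuantMode::useWeightOnly(true, true).hasPerGroupScaling());

    return 0;
}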
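Similarly, a small sketch of the inline helpers in the new stringUtils.h. It assumes the CUDA headers that file includes are on the include path, and the quoted include path is again an assumption.

// stringutils_usage_sketch.cpp -- illustrative only, not added by this patch.
#include <iostream>
#include <vector>

#include "tensorrt_llm/common/stringUtils.h" // assumed include path for the header above

int main()
{
    namespace tc = tensorrt_llm::common;

    // vec2str() renders "(2, 32, 128)" using the default ", " delimiter.
    std::vector<int> const shape{2, 32, 128};
    std::cout << tc::vec2str(shape) << '\n';

    // strStartsWith() is a thin wrapper around std::string::rfind(prefix, 0).
    std::cout << std::boolalpha << tc::strStartsWith("W4A16_AWQ", "W4A16") << '\n';

    return 0;
}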