diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml index 99ffd7c49cb..d5669886d18 100644 --- a/.github/workflows/release-docker.yml +++ b/.github/workflows/release-docker.yml @@ -14,7 +14,7 @@ jobs: environment: 'prod' strategy: matrix: - cuda_version: ['11.8.0', '12.1.1', '12.4.1'] + cuda_version: ['11.8.0', '12.1.1', '12.4.1', '12.5.1'] build_type: ['all', 'srt'] steps: - name: Delete huge unnecessary tools folder @@ -39,6 +39,8 @@ jobs: cuda_tag="cu121" elif [ "${{ matrix.cuda_version }}" = "12.4.1" ]; then cuda_tag="cu124" + elif [ "${{ matrix.cuda_version }}" = "12.5.1" ]; then + cuda_tag="cu125" else echo "Unsupported CUDA version" exit 1 @@ -58,7 +60,7 @@ jobs: docker build . -f docker/Dockerfile --build-arg CUDA_VERSION=${{ matrix.cuda_version }} --build-arg BUILD_TYPE=${{ matrix.build_type }} -t lmsysorg/sglang:${tag}${tag_suffix} --no-cache docker push lmsysorg/sglang:${tag}${tag_suffix} - if [ "${{ matrix.cuda_version }}" = "12.4.1" ]; then + if [ "${{ matrix.cuda_version }}" = "12.5.1" ]; then docker tag lmsysorg/sglang:${tag}${tag_suffix} lmsysorg/sglang:latest${tag_suffix} docker push lmsysorg/sglang:latest${tag_suffix} fi diff --git a/README.md b/README.md index e4c5f12f39a..b27271a1810 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s [Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487) ## Adoption and Sponsorship -The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS.org, Meituan, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI. +The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS.org, Meituan, Novita AI, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI. ## Acknowledgment and Citation We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). Please cite the paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful. 
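For context on what the expanded build matrix publishes, a minimal, hedged sketch of pulling and running the resulting images follows. The exact tag names (for example `v0.4.2.post1-cu125`) are assumed from the `cuda_tag` mapping above and the project's existing `<release>-<cuda_tag>` tag pattern, so confirm them on Docker Hub; the run flags mirror the install instructions touched later in this diff.

```bash
# Pull the new CUDA 12.5.1 build published by the expanded matrix
# (tag name assumed from the cuda_tag mapping; verify on Docker Hub).
docker pull lmsysorg/sglang:v0.4.2.post1-cu125

# With this change, the `latest` tags are promoted from the 12.5.1 build
# instead of the 12.4.1 build.
docker pull lmsysorg/sglang:latest

# Typical server launch inside the image, following the install docs
# (flags and model path are illustrative).
docker run --gpus all --shm-size 32g -p 30000:30000 \
    -v ~/.cache/huggingface:/root/.cache/huggingface \
    lmsysorg/sglang:v0.4.2.post1-cu125 \
    python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct \
    --host 0.0.0.0 --port 30000
```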
diff --git a/docker/Dockerfile b/docker/Dockerfile index 1fe702d4014..cec05825d0b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -30,6 +30,8 @@ RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu121; \ elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \ + elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118; \ python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ @@ -42,6 +44,8 @@ RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/; \ elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ + elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu118/torch2.4/flashinfer/; \ python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ @@ -53,6 +57,8 @@ RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \ python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.4/flashinfer/; \ elif [ "$CUDA_VERSION" = "12.4.1" ]; then \ python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ + elif [ "$CUDA_VERSION" = "12.5.1" ]; then \ + python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/; \ elif [ "$CUDA_VERSION" = "11.8.0" ]; then \ python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu118/torch2.4/flashinfer/; \ python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \ diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index f04254e54c9..af9f9e24df7 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -1,5 +1,5 @@ # Usage (to build SGLang ROCm docker image): -# docker build --build-arg SGL_BRANCH=v0.4.2 -t v0.4.2-rocm620 -f Dockerfile.rocm . +# docker build --build-arg SGL_BRANCH=v0.4.2.post1 -t v0.4.2.post1-rocm620 -f Dockerfile.rocm . # default base image ARG BASE_IMAGE="rocmshared/vllm-rocm:20250114-tuned-elementwise-layernorm" diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb index d69436eed17..273d943d120 100644 --- a/docs/backend/speculative_decoding.ipynb +++ b/docs/backend/speculative_decoding.ipynb @@ -8,10 +8,11 @@ "\n", "SGLang now provides an EAGLE-based speculative decoding option. 
The implementation aims to maximize speed and efficiency and is considered to be among the fastest in open-source LLM engines.\n", "\n", + "**Note:** Currently, Speculative Decoding in SGLang does not support radix cache.\n", + "\n", "To run the following tests or benchmarks, you also need to install [**cutex**](https://pypi.org/project/cutex/): \n", - "> ```bash\n", - "> pip install cutex\n", - "> ```\n", + "\n", + "`pip install cutex`\n", "\n", "### Performance Highlights\n", "\n", diff --git a/docs/developer/setup_github_runner.md b/docs/developer/setup_github_runner.md index 779c413977c..96c9cae0154 100644 --- a/docs/developer/setup_github_runner.md +++ b/docs/developer/setup_github_runner.md @@ -11,9 +11,9 @@ docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04 # Nvidia docker run --shm-size 128g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash # AMD -docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post1-rocm620 /bin/bash # AMD just the last 2 GPUs -docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.2.post1-rocm620 /bin/bash ``` ### Step 2: Configure the runner by `config.sh` diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 93c4273765d..85de12f9f47 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -2,7 +2,7 @@ ## Generative Models - Llama / Llama 2 / Llama 3 / Llama 3.1 / Llama 3.2 -- Mistral / Mixtral / Mistral NeMo +- Mistral / Mixtral / Mistral NeMo / Mistral Small 3 - Gemma / Gemma 2 - Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL - DeepSeek / DeepSeek 2 / [DeepSeek 3](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3) diff --git a/docs/start/install.md b/docs/start/install.md index 90964ac6b6c..a5012d6fc70 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -14,7 +14,7 @@ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/ ## Method 2: From source ``` # Use the last release branch -git clone -b v0.4.2 https://github.com/sgl-project/sglang.git +git clone -b v0.4.2.post1 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -28,7 +28,7 @@ Note: To AMD ROCm system with Instinct/MI GPUs, do following instead: ``` # Use the last release branch -git clone -b v0.4.2 https://github.com/sgl-project/sglang.git +git clone -b v0.4.2.post1 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -54,7 +54,7 @@ docker run --gpus all \ Note: To AMD ROCm system with Instinct/MI GPUs, it is recommended to use `docker/Dockerfile.rocm` to build images, example and usage as below: ```bash -docker build --build-arg SGL_BRANCH=v0.4.2 -t v0.4.2-rocm620 -f Dockerfile.rocm . +docker build --build-arg SGL_BRANCH=v0.4.2.post1 -t v0.4.2.post1-rocm620 -f Dockerfile.rocm . 
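# A hedged alternative to building locally: the matching prebuilt ROCm image
# referenced in the runner setup docs above can be pulled directly
# (image name assumed to be published for this release).
docker pull lmsysorg/sglang:v0.4.2.post1-rocm620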
alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --ipc=host \ --shm-size 16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ @@ -63,11 +63,11 @@ alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/d drun -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=" \ - v0.4.2-rocm620 \ + v0.4.2.post1-rocm620 \ python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 # Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default -drun v0.4.2-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 +drun v0.4.2.post1-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 ``` ## Method 4: Using docker compose diff --git a/python/pyproject.toml b/python/pyproject.toml index 11c984f82d7..d3d8c3f2a58 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.4.2" +version = "0.4.2.post1" description = "SGLang is yet another fast serving framework for large language models and vision language models." readme = "README.md" requires-python = ">=3.8" @@ -27,7 +27,7 @@ runtime_common = [ ] srt = [ "sglang[runtime_common]", "cuda-python", - "sgl-kernel>=0.0.3", "torch", "vllm==0.6.4.post1", + "sgl-kernel>=0.0.3.post1", "torch", "vllm==0.6.4.post1", "flashinfer==0.1.6" ] diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 7540515c5fd..cc6da781f56 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -800,7 +800,9 @@ def call_begin_forward( kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] kv_indices = torch.empty( - paged_kernel_lens_sum, dtype=torch.int32, device="cuda" + paged_kernel_lens_sum + 256, + dtype=torch.int32, + device=req_pool_indices.device, ) create_flashinfer_kv_indices_triton[(bs,)]( self.req_to_token, diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 527a7d499b6..dc53e4445db 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -17,6 +17,8 @@ import torch import torch.nn.functional as F +from sglang.srt.utils import get_compiler_backend + def fused_topk_native( hidden_states: torch.Tensor, @@ -74,6 +76,7 @@ def fused_topk( # This is used by the Deepseek-V2 model +@torch.compile(dynamic=True, backend=get_compiler_backend()) def grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, @@ -108,6 +111,7 @@ def grouped_topk( return topk_weights.to(torch.float32), topk_ids.to(torch.int32) +@torch.compile(dynamic=True, backend=get_compiler_backend()) def biased_grouped_topk( hidden_states: torch.Tensor, gating_output: torch.Tensor, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index b0b5b8952a1..f5a0005a282 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -290,6 +290,13 @@ def process_weights_after_loading(self, layer: Module) -> None: weight_scale, 
requires_grad=False ) layer.input_scale = None + else: + layer.weight = torch.nn.Parameter( + layer.weight.data, requires_grad=False + ) + layer.weight_scale_inv = torch.nn.Parameter( + layer.weight_scale_inv.data, requires_grad=False + ) return layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) # If checkpoint not serialized fp8, quantize the weights. diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index c20e478a1af..181aadeaa73 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -72,9 +72,11 @@ def forward( # NOTE: the top_p_renorm_prob from flashinfer has numerical problems, # https://github.com/flashinfer-ai/flashinfer/issues/708 # so we use the torch implementation. + + # clamp to avoid -inf logprobs = torch.log( top_p_normalize_probs_torch(probs, sampling_info.top_ps) - ) + ).clamp(min=torch.finfo(probs.dtype).min) max_top_k_round, batch_size = 32, probs.shape[0] uniform_samples = torch.rand( @@ -109,9 +111,10 @@ def forward( sampling_info.need_min_p_sampling, ) if return_logprob: + # clamp to avoid -inf logprobs = torch.log( top_p_normalize_probs_torch(probs, sampling_info.top_ps) - ) + ).clamp(min=torch.finfo(probs.dtype).min) else: raise ValueError( f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" diff --git a/python/sglang/version.py b/python/sglang/version.py index df12433297b..d1b3e6d0ae9 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.4.2" +__version__ = "0.4.2.post1" diff --git a/sgl-kernel/3rdparty/cutlass b/sgl-kernel/3rdparty/cutlass index bdd641790ad..3c28697b9f4 160000 --- a/sgl-kernel/3rdparty/cutlass +++ b/sgl-kernel/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit bdd641790ad49353b40ada41330552a78d2f8b5a +Subproject commit 3c28697b9f41fee4517b1758ffe83a85ac3ce2b4 diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp deleted file mode 100644 index eaaf6624472..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/common/assert.h" - -namespace -{ - -bool initCheckDebug() -{ - auto constexpr kDebugEnabled = "TLLM_DEBUG_MODE"; - auto const debugEnabled = std::getenv(kDebugEnabled); - return debugEnabled && debugEnabled[0] == '1'; -} -} // namespace - -bool DebugConfig::isCheckDebugEnabled() -{ - static bool const debugEnabled = initCheckDebug(); - return debugEnabled; -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.h b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.h deleted file mode 100644 index 7f51dbf1b41..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.h +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "tensorrt_llm/common/stringUtils.h" -#include "tensorrt_llm/common/tllmException.h" - -#include - -namespace tensorrt_llm::common -{ -[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "") -{ - throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str())); -} - -} // namespace tensorrt_llm::common - -class DebugConfig -{ -public: - static bool isCheckDebugEnabled(); -}; - -#if defined(_WIN32) -#define TLLM_LIKELY(x) (__assume((x) == 1), (x)) -#define TLLM_UNLIKELY(x) (__assume((x) == 0), (x)) -#else -#define TLLM_LIKELY(x) __builtin_expect((x), 1) -#define TLLM_UNLIKELY(x) __builtin_expect((x), 0) -#endif - -#define TLLM_CHECK(val) \ - do \ - { \ - TLLM_LIKELY(static_cast(val)) ? ((void) 0) \ - : tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \ - } while (0) - -#define TLLM_CHECK_WITH_INFO(val, info, ...) \ - do \ - { \ - TLLM_LIKELY(static_cast(val)) \ - ? ((void) 0) \ - : tensorrt_llm::common::throwRuntimeError( \ - __FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \ - } while (0) - -#define TLLM_CHECK_DEBUG(val) \ - do \ - { \ - if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \ - { \ - TLLM_LIKELY(static_cast(val)) ? ((void) 0) \ - : tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \ - } \ - } while (0) - -#define TLLM_CHECK_DEBUG_WITH_INFO(val, info, ...) \ - do \ - { \ - if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \ - { \ - TLLM_LIKELY(static_cast(val)) \ - ? ((void) 0) \ - : tensorrt_llm::common::throwRuntimeError( \ - __FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \ - } \ - } while (0) - -#define TLLM_THROW(...) \ - do \ - { \ - throw NEW_TLLM_EXCEPTION(__VA_ARGS__); \ - } while (0) - -#define TLLM_WRAP(ex) \ - NEW_TLLM_EXCEPTION("%s: %s", tensorrt_llm::common::TllmException::demangle(typeid(ex).name()).c_str(), ex.what()) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp deleted file mode 100644 index 351257f4d2e..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp +++ /dev/null @@ -1,360 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tensorrt_llm/common/cublasMMWrapper.h" -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/cublasVersionCheck.h" -#include - -#ifndef CUDART_VERSION -#error CUDART_VERSION Undefined! -#endif - -namespace tensorrt_llm -{ -namespace common -{ - -CublasMMWrapper::CublasMMWrapper(std::shared_ptr cublasHandle, - std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace) - : mCublasHandle(cublasHandle) - , mCublasLtHandle(cublasltHandle) - , mStream(stream) - , mCublasWorkspace(workspace) -{ -} - -CublasMMWrapper::~CublasMMWrapper() {} - -CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper) - : mCublasHandle(wrapper.mCublasHandle) - , mCublasLtHandle(wrapper.mCublasLtHandle) - , mStream(wrapper.mStream) -{ -} - -void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, - int const k, int const lda, int const ldb, int const ldc, int8_t fastAcc) -{ - // -------------------------------------- - // Create descriptors for the original matrices - check_cuda_error( - cublasLtMatrixLayoutCreate(&mADesc, mAType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda)); - check_cuda_error( - cublasLtMatrixLayoutCreate(&mBDesc, mBType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb)); - check_cuda_error(cublasLtMatrixLayoutCreate(&mCDesc, mCType, m, n, ldc)); - check_cuda_error(cublasLtMatmulDescCreate(&mOperationDesc, mComputeType, mScaleType)); - check_cuda_error(cublasLtMatmulDescSetAttribute( - mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t))); - check_cuda_error(cublasLtMatmulDescSetAttribute( - mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t))); - check_cuda_error( - cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAcc, sizeof(int8_t))); -} - -void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b) -{ - check_cuda_error( - cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(void*))); - check_cuda_error( - cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(void*))); -} - -void CublasMMWrapper::destroyDescriptors() -{ - check_cuda_error(cublasLtMatmulDescDestroy(mOperationDesc)); - check_cuda_error(cublasLtMatrixLayoutDestroy(mADesc)); - check_cuda_error(cublasLtMatrixLayoutDestroy(mBDesc)); - check_cuda_error(cublasLtMatrixLayoutDestroy(mCDesc)); - mOperationDesc = NULL; - mADesc = NULL; - mBDesc = NULL; - mCDesc = NULL; -} - -void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc) -{ - Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f); -} - -void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, - std::optional const& heuristic) -{ - if (heuristic) - { - Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, /* hasAlgo */ (*heuristic).algo, - (*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE, - /* usingCublasLt */ true); - } - else - { - Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, {}, /* hasAlgo */ false, - /* usingCublasLt */ true); - } -} - -void 
CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, - std::optional const& heuristic) -{ - if (heuristic) - { - Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, /* hasAlgo */ (*heuristic).algo, - (*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE, - /* usingCublasLt */ true); - } - else - { - Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false, - /* usingCublasLt */ true); - } -} - -void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta) -{ - bool usingCublasLt = mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3; - - Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false, - /* usingCublasLt */ usingCublasLt); -} - -void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, - cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt) -{ - half h_alpha = (half) (f_alpha); - half h_beta = (half) (f_beta); - - // TODO: default cublas libs - usingCublasLt = usingCublasLt && (mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3); - bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F; - int batch_count = 1; - // fp32 use cublas as default - // fp16 use cublasLt as default - void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - void const* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); - int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE; - - if (usingCublasLt) - { - if (hasAlgo) - { - hasAlgo = checkTactic(transa, transb, m, n, k, lda, ldb, ldc, algo); - } - - check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C, - mCDesc, C, mCDesc, (hasAlgo ? (&algo) : NULL), mCublasWorkspace, workspaceSize, mStream)); - - sync_check_cuda_error(); - } - else - { - check_cuda_error(cublasSetStream(getCublasHandle(), mStream)); - check_cuda_error(cublasSetWorkspace(getCublasHandle(), mCublasWorkspace, workspaceSize)); - // Go with default heuristic to choose tactic as cuBLAS does not allow to choose tactics in Ampere+ - cublasGemmAlgo_t cublasAlgo = CUBLAS_GEMM_DEFAULT; - check_cuda_error(cublasGemmEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, B, mBType, ldb, - beta, C, mCType, ldc, mComputeType, static_cast(cublasAlgo))); - sync_check_cuda_error(); - } -} - -void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, - int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, - const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha, - float const f_beta) -{ - half h_alpha = (half) f_alpha; - half h_beta = (half) f_beta; - - int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0; - void const* alpha = isFp16ComputeType ? 
reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - void const* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); - - check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, - strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType, - mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -} - -void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, - int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, - void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, - cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType) -{ - half h_alpha = (half) f_alpha; - half h_beta = (half) f_beta; - - bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0; - void const* alpha = isFp16ComputeType ? reinterpret_cast(&h_alpha) : reinterpret_cast(&f_alpha); - void const* beta = isFp16ComputeType ? reinterpret_cast(&h_beta) : reinterpret_cast(&f_beta); - - check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda, - strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType, - mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); -} - -void CublasMMWrapper::setWorkspace(void* workspace) -{ - mCublasWorkspace = workspace; -} - -void CublasMMWrapper::setFP32GemmConfig() -{ - setGemmConfig(CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F); -} - -void CublasMMWrapper::setFP16GemmConfig(cudaDataType_t outputType) -{ - setGemmConfig(CUDA_R_16F, CUDA_R_16F, outputType, CUDA_R_32F); -} - -#ifdef ENABLE_BF16 -void CublasMMWrapper::setBF16GemmConfig(cudaDataType_t outputType) -{ - setGemmConfig(CUDA_R_16BF, CUDA_R_16BF, outputType, CUDA_R_32F); -} -#endif - -#ifdef ENABLE_FP8 -void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType) -{ - setGemmConfig(CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, outputType, CUDA_R_32F); -} -#endif - -void CublasMMWrapper::setGemmConfig( - cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType) -{ - mAType = aType; - mBType = bType; - mCType = cType; - bool isFp16ComputeType = computeType == CUDA_R_16F; - if (isFp16ComputeType) - { - mComputeType = CUBLAS_COMPUTE_16F; - mScaleType = CUDA_R_16F; - } - else - { - mComputeType = CUBLAS_COMPUTE_32F; - mScaleType = CUDA_R_32F; - } -} - -CublasDataType CublasMMWrapper::getCublasDataType(cudaDataType_t data_type) -{ - if (data_type == CUDA_R_16F) - { - return HALF_DATATYPE; - } - else if (data_type == CUDA_R_32F) - { - return FLOAT_DATATYPE; - } - else if (data_type == CUDA_R_8I) - { - return INT8_DATATYPE; - } -#ifdef ENABLE_BF16 - else if (data_type == CUDA_R_16BF) - { - return BFLOAT16_DATATYPE; - } -#endif - return FLOAT_DATATYPE; -} - -void CublasMMWrapper::setStream(cudaStream_t stream) -{ - mStream = stream; -} - -bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, - int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo) -{ - TLLM_CHECK_WITH_INFO( - descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function"); - - int workspaceSize = mCublasWorkspace == NULL ? 
0 : CUBLAS_WORKSPACE_SIZE; - - cublasLtMatmulHeuristicResult_t heurResult; - cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck( - getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc, &algo, &heurResult); - - if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS - || heurResult.workspaceSize > CUBLAS_WORKSPACE_SIZE) - { - return false; - } - - sync_check_cuda_error(); - - return true; -} - -std::vector CublasMMWrapper::getTactics(cublasOperation_t transa, - cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc) -{ - TLLM_CHECK_WITH_INFO( - descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function"); - - auto const heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc); - - sync_check_cuda_error(); - - return heuristics; -} - -std::vector CublasMMWrapper::getTactics(cublasLtHandle_t lightHandle, - cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, - cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc) -{ -#if TLLM_CUBLAS_VER_LE(11, 4, 2) - TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2."); - return {}; -#else - std::vector heuristics(200); - cublasLtMatmulPreference_t preference; - check_cuda_error(cublasLtMatmulPreferenceCreate(&preference)); - check_cuda_error(cublasLtMatmulPreferenceInit(preference)); - uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE; - check_cuda_error(cublasLtMatmulPreferenceSetAttribute( - preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); - // Restrict reduction algorithms for numerical stability and better determinism - uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_MASK; - check_cuda_error(cublasLtMatmulPreferenceSetAttribute( - preference, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, &reduction_mask, sizeof(reduction_mask))); -#if TLLM_CUBLAS_VER_LT(12, 0, 0) - uint32_t pointer_mode_mask = 0; - check_cuda_error(cublasLtMatmulPreferenceSetAttribute( - preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask))); -#endif - - int return_count = 0; - check_cuda_error(cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, - heuristics.size(), heuristics.data(), &return_count)); - heuristics.resize(return_count); - - return heuristics; -#endif -} - -} // namespace common - -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h deleted file mode 100644 index 79b7c92a47d..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h +++ /dev/null @@ -1,148 +0,0 @@ -/* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "tensorrt_llm/common/cudaUtils.h" -#include -#include -#include -#include -#include -#include - -namespace tensorrt_llm -{ -namespace common -{ - -class CublasMMWrapper -{ -protected: - std::shared_ptr mCublasHandle; - std::shared_ptr mCublasLtHandle; - - cudaDataType_t mAType{}; - cudaDataType_t mBType{}; - cudaDataType_t mCType{}; - cublasComputeType_t mComputeType{}; - cudaDataType_t mScaleType{}; - - cublasLtMatmulDesc_t mOperationDesc{NULL}; - cublasLtMatrixLayout_t mADesc{NULL}; - cublasLtMatrixLayout_t mBDesc{NULL}; - cublasLtMatrixLayout_t mCDesc{NULL}; - - cudaStream_t mStream; - - void* mCublasWorkspace = nullptr; - -private: - bool descriptorsCreated() const - { - return mOperationDesc != NULL && mADesc != NULL && mBDesc != NULL && mCDesc != NULL; - } - -public: - CublasMMWrapper(std::shared_ptr cublasHandle, std::shared_ptr cublasLtHandle, - cudaStream_t stream, void* workspace); - - ~CublasMMWrapper(); - - CublasMMWrapper(CublasMMWrapper const& wrapper); - - /********************** GEMMs **********************/ - void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, - int const lda, void const* B, int const ldb, void* C, int const ldc); - - void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, - int const lda, void const* B, int const ldb, void* C, int const ldc, - std::optional const& algo); - - void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, - int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, - std::optional const& algo); - - void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, - int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta); - - void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A, - int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta, - cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt); - - void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB, - void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f, - float const f_beta = 0.0f); - - void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B, - cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType, - int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType); - - /********************** Tactic selection helpers **********************/ - bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo); - - std::vector getTactics(cublasOperation_t transa, cublasOperation_t transb, - int const m, int const n, int const k, int const lda, int const ldb, int const ldc); - - std::vector getTactics(cublasLtHandle_t lightHandle, - cublasLtMatmulDesc_t 
computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, - cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc); - - using MatrixLayout = std::tuple; - using cache_idx_t = std::tuple>; - - MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc); - - /********************** Utils **********************/ - void setWorkspace(void* workspace); - - void setFP32GemmConfig(); - void setFP16GemmConfig(cudaDataType_t outputType = CUDA_R_16F); -#ifdef ENABLE_BF16 - void setBF16GemmConfig(cudaDataType_t outputType = CUDA_R_16BF); -#endif -#ifdef ENABLE_FP8 - void setFP8GemmConfig(cudaDataType_t outputType = CUDA_R_16F); -#endif - - void setStream(cudaStream_t stream); - - void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType); - - CublasDataType getCublasDataType(cudaDataType_t data_type); - - void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, - int const lda, int const ldb, int const ldc, int8_t fastAcc = 0); - void setScaleDescriptors(void* scale_a, void* scale_b); - void destroyDescriptors(); - - cublasHandle_t getCublasHandle() - { - return *(this->mCublasHandle); - } - - cublasLtHandle_t getCublasLtHandle() const - { - return *(this->mCublasLtHandle); - } -}; - -} // namespace common - -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h deleted file mode 100644 index 1ee72c63566..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -// We don't want to include cublas_api.h. It contains the CUBLAS_VER_* macro -// definition which is not sufficient to determine if we include cublas.h, -// cublas_v2.h or cublasLt.h. 
- -#define TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) (MAJOR * 10000 + MINOR * 100 + PATCH) -#define TLLM_CUBLAS_VER_LE(MAJOR, MINOR, PATCH) \ - TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ - <= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) -#define TLLM_CUBLAS_VER_LT(MAJOR, MINOR, PATCH) \ - TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ - < TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) -#define TLLM_CUBLAS_VER_GE(MAJOR, MINOR, PATCH) \ - TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ - >= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) -#define TLLM_CUBLAS_VER_GT(MAJOR, MINOR, PATCH) \ - TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ - > TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh deleted file mode 100644 index 0519251e6fd..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh +++ /dev/null @@ -1,313 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "tensorrt_llm/common/cudaBf16Wrapper.h" -#include -#include - -namespace tensorrt_llm -{ -namespace common -{ - -#ifdef ENABLE_BF16 -inline __device__ float2 bf1622float2(const __nv_bfloat162 val) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = __low2float(val); - f_val.y = __high2float(val); - return f_val; -#else - return __bfloat1622float2(val); -#endif -} - -inline __device__ int16_t bf1622int16(__nv_bfloat162 val) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float2 f_val; - f_val.x = max(min(__low2float(val), 127.f), -128.f); - f_val.y = max(min(__high2float(val), 127.f), -128.f); - - union - { - int8_t int8[2]; - int16_t int16; - }; - - int8[0] = static_cast(static_cast(f_val.x)); - int8[1] = static_cast(static_cast(f_val.y)); - return int16; -#else - val = __hmin2(val, make_bfloat162(127., 127.)); - val = __hmax2(val, make_bfloat162(-128., -128.)); - - union - { - int8_t int8[2]; - int16_t int16; - }; - - int8[0] = static_cast(static_cast(val.x)); - int8[1] = static_cast(static_cast(val.y)); - return int16; -#endif -} - -inline __device__ __nv_bfloat162 float22bf162(const float2 val) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __floats2bfloat162_rn(val.x, val.y); -#else - return __float22bfloat162_rn(val); -#endif -} - -inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - __nv_bfloat162 val2; - val2.x = val; - val2.y = val; - return val2; -#else - return __bfloat162bfloat162(val); -#endif -} - -inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = 
__low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); -#else - return __hadd2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y)); -#else - return __hadd(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); -#else - return __hsub2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y)); -#else - return __hsub(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); -#else - return __hmul2(x, y); -#endif -} - -inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y)); -#else - return __hmul(x, y); -#endif -} - -inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh, fyl, fyh, fzl, fzh; - fxl = __low2float(x); - fxh = __high2float(x); - fyl = __low2float(y); - fyh = __high2float(y); - fzl = __low2float(z); - fzh = __high2float(z); - return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); -#else - return __hfma2(x, y, z); -#endif -} - -inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); -#else - return __hfma(x, y, z); -#endif -} - -inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fxl, fxh; - fxl = __low2float(x); - fxh = __high2float(x); - ; - return __floats2bfloat162_rn(expf(fxl), expf(fxh)); -#else - return h2exp(x); -#endif -} - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020) - -inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) -{ - __nv_bfloat162 t; - t.x = x; - t.y = y; - return t; -} -#endif -#endif - -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); -#else - return a + b + c; -#endif -} - -inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + 
__bfloat162float(c) + __bfloat162float(d)); -#else - return (__nv_bfloat16) ((float) a + (float) b + (float) c + (float) d); -#endif -} - -inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); -#else - return a + b + c; -#endif -} - -inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); -#else - return a * b * c; -#endif -} - -inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); -#else - return a * b * c; -#endif -} - -inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; - fal = __low2float(a); - fah = __high2float(a); - fbl = __low2float(b); - fbh = __high2float(b); - fcl = __low2float(c); - fch = __high2float(c); - fdl = __low2float(d); - fdh = __high2float(d); - return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); -#else - return a * b * c + d; -#endif -} - -#endif // ENABLE_BF16 - -} // namespace common -} // namespace tensorrt_llm - -// Operator definitions intentionally in global namespace -namespace -{ -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) -#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020) - -inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) -{ - return tensorrt_llm::common::bf16hmul2(x, y); -}; - -inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) -{ - return tensorrt_llm::common::bf16hadd2(x, y); -}; -#endif -#endif -} // namespace diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h deleted file mode 100644 index fb2a89af5cd..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Wrapper.h +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#ifdef ENABLE_BF16 -#include -#endif diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp deleted file mode 100644 index 7eca46a1cab..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#define CUDA_LIB_NAME "cuda" - -#if defined(_WIN32) -#include -#define dllOpen(name) LoadLibrary("nv" name ".dll") -#define dllClose(handle) FreeLibrary(static_cast(handle)) -#define dllGetSym(handle, name) static_cast(GetProcAddress(static_cast(handle), name)) -#else // For non-Windows platforms -#include -#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY) -#define dllClose(handle) dlclose(handle) -#define dllGetSym(handle, name) dlsym(handle, name) -#endif // defined(_WIN32) - -#include "cudaDriverWrapper.h" -#include "tensorrt_llm/common/assert.h" -#include -#include - -namespace tensorrt_llm::common -{ - -std::shared_ptr CUDADriverWrapper::getInstance() -{ - static std::mutex mutex; - static std::weak_ptr instance; - std::shared_ptr result = instance.lock(); - if (result) - { - return result; - } - - std::lock_guard lock(mutex); - result = instance.lock(); - if (!result) - { - result = std::shared_ptr(new CUDADriverWrapper()); - instance = result; - } - return result; -} - -CUDADriverWrapper::CUDADriverWrapper() - : handle(dllOpen(CUDA_LIB_NAME)) -{ - - TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly."); - - auto load_sym = [](void* handle, char const* name) - { - void* ret = dllGetSym(handle, name); - return ret; - }; - - *reinterpret_cast(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName"); - *reinterpret_cast(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage"); - *reinterpret_cast(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute"); - *reinterpret_cast(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete"); - *reinterpret_cast(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload"); - *reinterpret_cast(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy"); - *reinterpret_cast(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData"); - *reinterpret_cast(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2"); - *reinterpret_cast(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction"); - *reinterpret_cast(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2"); - *reinterpret_cast(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2"); - *reinterpret_cast(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2"); - *reinterpret_cast(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel"); - *reinterpret_cast(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel"); - *reinterpret_cast(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled"); - *reinterpret_cast(&_cuMemcpyDtoH) = 
load_sym(handle, "cuMemcpyDtoH_v2"); -} - -CUDADriverWrapper::~CUDADriverWrapper() -{ - dllClose(handle); -} - -CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const -{ - return (*_cuGetErrorName)(error, pStr); -} - -CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const -{ - return (*_cuGetErrorMessage)(error, pStr); -} - -CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const -{ - return (*_cuFuncSetAttribute)(hfunc, attrib, value); -} - -CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const -{ - return (*_cuLinkComplete)(state, cubinOut, sizeOut); -} - -CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const -{ - return (*_cuModuleUnload)(hmod); -} - -CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const -{ - return (*_cuLinkDestroy)(state); -} - -CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const -{ - return (*_cuModuleLoadData)(module, image); -} - -CUresult CUDADriverWrapper::cuLinkCreate( - unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const -{ - return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut); -} - -CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const -{ - return (*_cuModuleGetFunction)(hfunc, hmod, name); -} - -CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const -{ - return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name); -} - -CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, - unsigned int numOptions, CUjit_option* options, void** optionValues) const -{ - return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues); -} - -CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, - char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const -{ - return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues); -} - -CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, - unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, - unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const -{ - return (*_cuLaunchCooperativeKernel)( - f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams); -} - -CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, - unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, - unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const -{ - return (*_cuLaunchKernel)( - f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra); -} - -CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, - cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const -{ - 
return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides, - boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill); -} - -CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const -{ - return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount); -} - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h deleted file mode 100644 index c4d470a85f0..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef CUDA_DRIVER_WRAPPER_H -#define CUDA_DRIVER_WRAPPER_H - -#include "tensorrt_llm/common/assert.h" -#include -#include -#include -#include - -namespace tensorrt_llm::common -{ - -class CUDADriverWrapper -{ -public: - static std::shared_ptr getInstance(); - - ~CUDADriverWrapper(); - CUDADriverWrapper(CUDADriverWrapper const&) = delete; - CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete; - CUDADriverWrapper(CUDADriverWrapper&&) = delete; - CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete; - - CUresult cuGetErrorName(CUresult error, char const** pStr) const; - - CUresult cuGetErrorMessage(CUresult error, char const** pStr) const; - - CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const; - - CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const; - - CUresult cuModuleUnload(CUmodule hmod) const; - - CUresult cuLinkDestroy(CUlinkState state) const; - - CUresult cuModuleLoadData(CUmodule* module, void const* image) const; - - CUresult cuLinkCreate( - unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const; - - CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const; - - CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const; - - CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions, - CUjit_option* options, void** optionValues) const; - - CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name, - unsigned int numOptions, CUjit_option* options, void** optionValues) const; - - CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, - unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, - unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const; - - CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, - unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, - CUstream hStream, void** 
kernelParams, void** extra) const; - - CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, - void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim, - cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, - CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const; - - CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const; - -private: - void* handle; - CUDADriverWrapper(); - - CUresult (*_cuGetErrorName)(CUresult, char const**); - CUresult (*_cuGetErrorMessage)(CUresult, char const**); - CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int); - CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*); - CUresult (*_cuModuleUnload)(CUmodule); - CUresult (*_cuLinkDestroy)(CUlinkState); - CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*); - CUresult (*_cuModuleLoadData)(CUmodule*, void const*); - CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*); - CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*); - CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**); - CUresult (*_cuLinkAddData)( - CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**); - CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, - unsigned int, unsigned int, unsigned int, CUstream, void**); - CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, - unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, - CUstream hStream, void** kernelParams, void** extra); - CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, - cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, - cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, - CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); - CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount); -}; - -template -void checkDriver( - T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line) -{ - if (result) - { - char const* errorName = nullptr; - char const* errorMsg = nullptr; - wrap.cuGetErrorName(result, &errorName); - wrap.cuGetErrorMessage(result, &errorMsg); - throw TllmException( - file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg)); - } -} - -} // namespace tensorrt_llm::common - -/* - * Macros compliant with TensorRT coding conventions - */ -#define TLLM_CU_CHECK(stat) \ - do \ - { \ - tensorrt_llm::common::checkDriver( \ - (stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \ - } while (0) - -#endif // CUDA_DRIVER_WRAPPER_H diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu deleted file mode 100644 index 8e140609f2a..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu +++ /dev/null @@ -1,436 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/common/cudaFp8Utils.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/common/reduceKernelUtils.cuh" -#include -#include -#include -#include -#include - -namespace tensorrt_llm -{ -namespace common -{ -#ifdef ENABLE_FP8 - -constexpr int CTA_SIZE = 256; - -template -__inline__ __device__ float scale(float a, float b) -{ - return QUANTIZE ? a / b : a * b; -} - -template -__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda) -{ - for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x) - { - - if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL) - { - output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[i % lda]))); - } - else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN) - { - output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[i / lda]))); - } - else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR) - { - output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[0]))); - } - } -} - -template -void invokeQuantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda, - QuantizeMode quantize_mode, cudaStream_t stream) -{ - dim3 grid(1024); - dim3 block(CTA_SIZE); - if (quantize_mode == QuantizeMode::PER_CHANNEL) - { - scaleMatrix - <<>>(output, input_scale, input, numel, lda); - } - else if (quantize_mode == QuantizeMode::PER_TOKEN) - { - scaleMatrix<<>>(output, input_scale, input, numel, lda); - } - else if (quantize_mode == QuantizeMode::PER_TENSOR) - { - scaleMatrix<<>>(output, input_scale, input, numel, lda); - } - sync_check_cuda_error(); -} - -template -void invokeDequantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda, - QuantizeMode quantize_mode, cudaStream_t stream) -{ - dim3 grid(1024); - dim3 block(CTA_SIZE); - if (quantize_mode == QuantizeMode::PER_CHANNEL) - { - scaleMatrix - <<>>(output, input_scale, input, numel, lda); - } - else if (quantize_mode == QuantizeMode::PER_TOKEN) - { - scaleMatrix<<>>(output, input_scale, input, numel, lda); - } - else if (quantize_mode == QuantizeMode::PER_TENSOR) - { - scaleMatrix - <<>>(output, input_scale, input, numel, lda); - } - sync_check_cuda_error(); -} - -template -__global__ void fakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel) -{ - for (int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < numel; tid += blockDim.x * gridDim.x) - { - T_FAKE tmp = (T_FAKE) (static_cast(src[tid])); - dst[tid] = (T_OUT) (static_cast(tmp)); - } -} - -template -void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream) -{ - fakeQuantize<<<1024, CTA_SIZE, 0, stream>>>(dst, src, numel); - sync_check_cuda_error(); -} - -template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>( - float* dst, float const* src, const int64_t numel, cudaStream_t stream); -template void 
invokeFakeQuantize( - float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream); -template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>( - half* dst, half const* src, const int64_t numel, cudaStream_t stream); -template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>( - __nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream); - -template void invokeFakeQuantize( - half* dst, float const* src, const int64_t numel, cudaStream_t stream); - -__device__ float atomicMaxExtd(float* address, float val) -{ - assert(val >= 0); - unsigned int* address_as_u = reinterpret_cast(address); - unsigned int old = atomicMax(address_as_u, __float_as_uint(val)); - return __uint_as_float(old); -} - -template -inline __device__ T atomicMaxExtdV2(T* address, T val) -{ -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - static_assert(std::is_same_v | std::is_same_v, "T needs to be either half or bfloat16"); - // The address in 64 bits. - uint64_t address_u64 = reinterpret_cast(address); - - // Pack the input value into 32 bits. - union - { - T v[2]; - uint16_t u[2]; - } old, tmp = {}; - - int const loc = (address_u64 & 0x2) >> 1; - tmp.v[loc] = val; - - // 4B aligned pointer. - auto aligned_address = reinterpret_cast(address_u64 & ~0x3ull); - - if constexpr (std::is_same_v) - { - asm volatile("atom.global.v2.f16.max.noftz {%0, %1}, [%2], {%3, %4};" - : "=h"(old.u[0]), "=h"(old.u[1]) - : "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1])); - } - if constexpr (std::is_same_v) - { - asm volatile("atom.global.v2.bf16.max.noftz {%0, %1}, [%2], {%3, %4};" - : "=h"(old.u[0]), "=h"(old.u[1]) - : "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1])); - } - - // Return the correct half. - return old.v[loc]; -#endif -} - -__device__ half atomicMaxExtd(half* address, half val) -{ - unsigned short int* address_as_u = reinterpret_cast(address); - unsigned short int old = *address_as_u, assumed; - - while (val > __ushort_as_half(old)) - { - assumed = old; - old = atomicCAS(address_as_u, assumed, __half_as_ushort(val)); - } - - return __ushort_as_half(old); -} - -__device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val) -{ -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - unsigned short int* address_as_u = reinterpret_cast(address); - unsigned short int old = *address_as_u, assumed; - - while (val > __ushort_as_bfloat16(old)) - { - assumed = old; - old = atomicCAS(address_as_u, assumed, __bfloat16_as_ushort(val)); - } - - return __ushort_as_bfloat16(old); -#else - assert(0); - asm volatile("brkpt;\n" ::); - return __nv_bfloat16(0); -#endif -} - -template -__global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t size, const int64_t n) -{ - constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); - if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL) - { - for (int64_t col = threadIdx.x; col < n; col += blockDim.x) - { - float max = 0.f; - for (int64_t i = col + n * blockIdx.x; i < size; i += gridDim.x * n) - { - auto val = fabs(static_cast(weights[i])); - max = max > val ? 
max : val; - } - auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - if constexpr (std::is_same_v) - { - atomicMaxExtd(quant_ptr + col, scale); - } - else - { - auto const address_u64 = reinterpret_cast(quant_ptr + col); - if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0)) - atomicMaxExtd(quant_ptr + col, scale); - else - atomicMaxExtdV2(quant_ptr + col, scale); - } -#else // Vector atomics require __CUDA_ARCH__ >= 900 - atomicMaxExtd(quant_ptr + col, scale); -#endif - } - } - else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN) - { - auto const nrows = size / n; - for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x) - { - float max = 0.f; - for (int64_t i = threadIdx.x; i < n; i += blockDim.x) - { - auto val = fabs(static_cast(weights[row * n + i])); - max = max > val ? max : val; - } - max = blockReduceMax(max); - if (threadIdx.x == 0) - { - auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); - quant_ptr[row] = scale; - } - } - } - else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR) - { - float max = 0.f; - for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += gridDim.x * blockDim.x) - { - auto val = fabs(static_cast(weights[i])); - max = max > val ? max : val; - } - max = blockReduceMax(max); - if (threadIdx.x == 0) - { - auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); - atomicMaxExtd(quant_ptr, scale); - } - } -} - -template -void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t numel, const int64_t lda, - QuantizeMode quantize_mode, cudaStream_t stream) -{ - if (quantize_mode == QuantizeMode::PER_TOKEN) - { - dim3 block(CTA_SIZE); - dim3 grid(numel / lda); - computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); - } - else if (quantize_mode == QuantizeMode::PER_CHANNEL) - { - dim3 block(CTA_SIZE); - dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE); - cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream); - sync_check_cuda_error(); - computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); - } - else if (quantize_mode == QuantizeMode::PER_TENSOR) - { - dim3 block(1024); - dim3 grid(1024); - cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream); - sync_check_cuda_error(); - computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); - } - sync_check_cuda_error(); -} - -#define DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(type_scale, type_in) \ - template void invokeComputeFP8QuantizeScale(type_scale * input_scale, type_in const* weights, \ - int64_t numel, int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream); - -DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(half, half); -DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, half); -DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, float); -#ifdef ENABLE_BF16 -DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(__nv_bfloat16, __nv_bfloat16); -DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, __nv_bfloat16); -#endif - -template -__global__ void dynamicQuantizeMatrixPerToken( - T_OUT* output, T_S* quant_ptr, T_IN const* input, int64_t numel, int64_t lda) -{ - extern __shared__ __align__(sizeof(float)) char _shmem[]; - T_IN* shmem = reinterpret_cast(_shmem); - constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); - auto const nrows = numel / lda; - for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x) - { - float max = 0.f; - for (int64_t i = threadIdx.x; i < lda; i += blockDim.x) - { - auto const in = input[row * lda + i]; 
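// Per-token dynamic FP8 quantization, as implemented in this kernel:
//   s    = max(amax_row / FP8_E4M3_MAX, 1.0f / (FP8_E4M3_MAX * 512.f))
//   q[i] = x[i] / s                      (scale<true> divides when quantizing)
// e.g. a row with amax_row = 2.0f gives s ~ 2.0/448 ~ 4.46e-3, so x = 1.0f maps to ~224,
// well inside the e4m3 range. The row is staged in shared memory here so it can be
// re-read for the quantize pass once blockAllReduceMax() has produced amax_row.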
- shmem[i] = in; - auto val = fabs(static_cast(in)); - max = max > val ? max : val; - } - max = blockAllReduceMax(max); // __syncthreads() called so we can read shmem - auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); - for (int64_t i = threadIdx.x; i < lda; i += blockDim.x) - { - // true means we are quantizing - output[row * lda + i] = (T_OUT) scale(static_cast(shmem[i]), static_cast(s)); - } - if (threadIdx.x == 0) - { - quant_ptr[row] = s; - } - } -} - -template -void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* input, const int64_t numel, - const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream) -{ - if (quantize_mode == QuantizeMode::PER_TOKEN) - { - dim3 grid(numel / lda); - bool use_shmem = true; - auto const shmem_size = lda * sizeof(T_IN); - if (shmem_size >= (48 << 10)) - { - cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); - use_shmem = ret == cudaSuccess; - } - if (use_shmem) - { - // ensure the threadblock is as large as possible to increase occupancy - dim3 block(std::min((lda + 31) / 32 * 32, static_cast(1024))); - dynamicQuantizeMatrixPerToken<<>>(output, quant_ptr, input, numel, lda); - } - else - { - dim3 block(CTA_SIZE); - computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); - sync_check_cuda_error(); - invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); - } - } - else if (quantize_mode == QuantizeMode::PER_CHANNEL) - { - dim3 block(CTA_SIZE); - dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE); - cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream); - sync_check_cuda_error(); - computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); - sync_check_cuda_error(); - invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); - } - else if (quantize_mode == QuantizeMode::PER_TENSOR) - { - dim3 block(1024); - dim3 grid(1024); - cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream); - sync_check_cuda_error(); - computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); - sync_check_cuda_error(); - invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); - } - sync_check_cuda_error(); -} - -#define DEFINE_INVOKE_QUANTIZE_MATRIX(type_out, type_scale, type_in) \ - template void invokeQuantizeMatrix(type_out * output, \ - type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ - cudaStream_t stream); \ - template void invokeDequantizeMatrix(type_out * output, \ - type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ - cudaStream_t stream); \ - template void invokeComputeScalesAndQuantizeMatrix(type_out * output, \ - type_scale * input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ - cudaStream_t stream); - -#ifdef ENABLE_FP8 -DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, float); -DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, half); -DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, half, half); -DEFINE_INVOKE_QUANTIZE_MATRIX(half, half, __nv_fp8_e4m3); -DEFINE_INVOKE_QUANTIZE_MATRIX(float, float, __nv_fp8_e4m3); -DEFINE_INVOKE_QUANTIZE_MATRIX(half, float, __nv_fp8_e4m3); -#ifdef ENABLE_BF16 -DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16); -DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_bfloat16, __nv_bfloat16, __nv_fp8_e4m3); -#endif -#endif - -#endif // ENABLE_FP8 -} // namespace 
common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h deleted file mode 100644 index aa93b55a579..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.h +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#ifdef ENABLE_FP8 -#include -#include -#include - -#define FP8_MHA -#define FUSE_GEMM_ACT -#define FP8_GEMM_OUTPUT_QUANT_DISABLE - -#ifdef FUSE_GEMM_ACT -#define USE_QGMMA -#endif - -namespace tensorrt_llm -{ -namespace common -{ - -constexpr float FP8_E4M3_MAX = 448.0f; - -enum QuantizeMode -{ - PER_CHANNEL, - PER_TENSOR, - PER_CHANNEL_WEIGHT_PER_TENSOR_ACT, - PER_TOKEN, -}; - -// Packed Data Type -typedef struct __CUDA_ALIGN__(32) -{ - float array[8]; -} float8; - -typedef struct __CUDA_ALIGN__(16) -{ - half array[8]; -} half8; - -typedef struct __CUDA_ALIGN__(8) -{ - half2 array[2]; -} half2_2; - -typedef struct __CUDA_ALIGN__(8) -{ - half array[4]; -} half_4; - -#ifdef ENABLE_BF16 -typedef struct __CUDA_ALIGN__(4) -{ - __nv_bfloat16 array[2]; -} __nv_bfloat16_2; - -typedef struct __CUDA_ALIGN__(8) -{ - __nv_bfloat162 x, y; -} __nv_bfloat162_2_xy; - -typedef struct __CUDA_ALIGN__(8) -{ - __nv_bfloat16 array[4]; -} __nv_bfloat164; - -typedef struct __CUDA_ALIGN__(8) -{ - __nv_bfloat162 array[2]; -} __nv_bfloat162_2; - -typedef struct __CUDA_ALIGN__(16) -{ - __nv_bfloat16 array[8]; -} __nv_bfloat168; - -typedef struct __CUDA_ALIGN__(16) -{ - __nv_bfloat162 array[4]; -} __nv_bfloat162_4; - -typedef struct __CUDA_ALIGN__(32) -{ - __nv_bfloat16 array[16]; -} __nv_bfloat1616; -#endif - -#ifdef ENABLE_FP8 -typedef struct __CUDA_ALIGN__(2) -{ - __nv_fp8_e4m3 array[2]; -} __nv_fp8_2_e4m3; - -typedef struct __CUDA_ALIGN__(4) -{ - __nv_fp8_e4m3 array[4]; -} __nv_fp8_4_e4m3; - -typedef struct __CUDA_ALIGN__(4) -{ - __nv_fp8x2_e4m3 array[2]; -} __nv_fp8x2_x2_e4m3; - -typedef struct __CUDA_ALIGN__(8) -{ - __nv_fp8_e4m3 array[8]; -} __nv_fp8_8_e4m3; - -typedef struct __CUDA_ALIGN__(8) -{ - __nv_fp8x2_e4m3 array[4]; -} __nv_fp8x2_x4_e4m3; - -typedef struct __CUDA_ALIGN__(16) -{ - __nv_fp8_e4m3 array[16]; -} __nv_fp8x16_e4m3; -#endif - -// only BF16 and FP8 -template -struct PackType -{ - using type = float; -}; - -#ifdef ENABLE_BF16 -template <> -struct PackType<__nv_bfloat16, 2> -{ - using type = __nv_bfloat16_2; -}; - -template <> -struct PackType<__nv_bfloat16, 4> -{ - using type = __nv_bfloat164; -}; - -template <> -struct PackType<__nv_bfloat16, 8> -{ - using type = __nv_bfloat168; -}; -#endif - -#ifdef ENABLE_FP8 -template <> -struct PackType<__nv_fp8_e4m3, 2> -{ - using type = __nv_fp8_2_e4m3; -}; - -template <> -struct PackType<__nv_fp8_e4m3, 4> -{ - using type = __nv_fp8_4_e4m3; -}; - -template <> -struct PackType<__nv_fp8_e4m3, 8> -{ - using type = __nv_fp8_8_e4m3; -}; -#endif - -__inline__ __device__ void 
fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, __nv_fp8x4_e4m3 const* in) -{ - const char4 tmp_val = reinterpret_cast(in)[0]; - *out1 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], - (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); - *out2 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0], - (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]); -} - -__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(__nv_fp8x2_e4m3 const* in) -{ - const char2 tmp_val = reinterpret_cast(in)[0]; - __nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], - (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); - return out; -} - -__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, __nv_fp8x4_e4m3 const* in) -{ - const char4 tmp_val = reinterpret_cast(in)[0]; - *out1 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], - (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); - *out2 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0], - (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]); -} - -__inline__ __device__ half2 fp8x2_e4m3_to_half2(__nv_fp8x2_e4m3 const* in) -{ - const char2 tmp_val = reinterpret_cast(in)[0]; - half2 out = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0], - (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]); - return out; -} - -template -void invokeQuantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda, - QuantizeMode quantize_mode, cudaStream_t stream); - -template -void invokeDequantizeMatrix(T_OUT* output, T_S const* input_qua_amax_ptr, T_IN const* input, int64_t numel, int64_t lda, - QuantizeMode quantize_mode, cudaStream_t stream); - -template -void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream); - -template -void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t k, const int64_t lda, - QuantizeMode quantize_mode, cudaStream_t stream); - -template -void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* weights, const int64_t numel, - const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream); - -} // namespace common -} // namespace tensorrt_llm -#endif // ENABLE_FP8 diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh deleted file mode 100644 index a0463a3a49e..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh +++ /dev/null @@ -1,752 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" -#include "tensorrt_llm/common/cudaBf16Wrapper.h" -#include "tensorrt_llm/common/cudaFp8Utils.h" -#include -#include -#include -#if ENABLE_BF16 -#include -#endif - -namespace tensorrt_llm -{ -namespace common -{ - -template -inline __device__ T ldg(T const* val) -{ - return __ldg(val); -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 ldg(__nv_bfloat162 const* val) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return val[0]; -#else - return __ldg(val); -#endif -} - -template <> -inline __device__ __nv_bfloat16 ldg(__nv_bfloat16 const* val) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - return val[0]; -#else - return __ldg(val); -#endif -} -#endif // ENABLE_BF16 - -// Get type2 from type or vice versa (applied to half and bfloat16) -template -struct TypeConverter -{ - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter -{ - using Type = half; -}; - -template <> -struct TypeConverter -{ - using Type = half2; -}; - -#if ENABLE_BF16 -template <> -struct TypeConverter<__nv_bfloat162> -{ - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> -{ - using Type = __nv_bfloat162; -}; -#endif // ENABLE_BF16 - -// Defined math operations (bfloat16 fallback to fp32 when it is not supported) -template -inline __device__ T hadd2(T a, T b) -{ - return __hadd2(a, b); -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hadd2(a, b); -} -#endif // ENABLE_BF16 - -template -inline __device__ T add(T a, T b) -{ - return a + b; -} - -template <> -inline __device__ half2 add(half2 a, half2 b) -{ - return __hadd2(a, b); -} - -template <> -inline __device__ half add(half a, half b) -{ - return __hadd(a, b); -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hadd2(a, b); -} - -template <> -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) -{ - return bf16hadd(a, b); -} - -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b) -{ - return bf16hadd(a, __float2bfloat16(b)); -} -#endif // ENABLE_BF16 - -// applies to all 4 values addition -template -inline __device__ T add(T a, T b, T c) -{ - return a + b + c; -} - -#if ENABLE_BF16 -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) -{ - return bf16hadd(a, b, c); -} - -inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hadd2(a, b, c); -} -#endif // ENABLE_BF16 - -// applies to all 4 values addition -template -inline __device__ T add(T a, T b, T c, T d) -{ - return (T) ((float) a + (float) b + (float) c + (float) d); -} - -#if ENABLE_BF16 -inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) -{ - return bf16hadd(a, b, c, d); -} -#endif // ENABLE_BF16 - -template -inline __device__ T hsub2(T a, T b) -{ - return __hsub2(a, b); -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hsub2(a, b); -} -#endif // ENABLE_BF16 - -template -inline __device__ T hmul2(T a, T b) -{ - return __hmul2(a, b); -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b) -{ - return bf16hmul2(a, b); -} -#endif // ENABLE_BF16 - -template -inline __device__ T hmul2(T a, T b, T c) -{ - 
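// Generic path: plain three-operand product on packed types. The __nv_bfloat162
// specialization below routes through bf16hmul2(), which (per the header comment
// above) falls back to fp32 math on architectures without native bf16 support.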
return a * b * c; -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hmul2(a, b, c); -} -#endif // ENABLE_BF16 - -template -inline __device__ T mul(T a, T b, T c) -{ - return a * b * c; -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) -{ - return bf16hmul(a, b, c); -} - -inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hmul2(a, b, c); -} -#endif // ENABLE_BF16 - -template -inline __device__ T fma(T a, T b, T c, T d) -{ - return a * b * c + d; -} - -#if ENABLE_BF16 -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) -{ - return bf16hfma2(a, b, c, d); -} -#endif // ENABLE_BF16 - -template -inline __device__ T fma(T a, T b, T c) -{ - return a * b + c; -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) -{ - return bf16hfma2(a, b, c); -} - -template <> -inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) -{ - return bf16hfma(a, b, c); -} -#endif // ENABLE_BF16 - -template -inline __device__ T hexp2(T a) -{ - return h2exp(a); -} - -#if ENABLE_BF16 -template <> -inline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a) -{ - return bf16exp2(a); -} -#endif // ENABLE_BF16 - -template -__device__ inline T_OUT cuda_cast(T_IN val) -{ - return val; -} - -template <> -__device__ inline float2 cuda_cast(int2 val) -{ - return make_float2(val.x, val.y); -} - -template <> -__device__ inline float2 cuda_cast(float val) -{ - return make_float2(val, val); -} - -template <> -__device__ inline float2 cuda_cast(half2 val) -{ - return __half22float2(val); -} - -template <> -__device__ inline half2 cuda_cast(float2 val) -{ - return __float22half2_rn(val); -} - -template <> -__device__ inline half2 cuda_cast(float val) -{ - return __float2half2_rn(val); -} - -template <> -__device__ inline half2 cuda_cast(half val) -{ - return __half2half2(val); -} - -template <> -__device__ inline int8_t cuda_cast(half val) -{ - union - { - int8_t int8[2]; - int16_t int16; - }; - - union - { - half fp16; - int16_t int16_in; - }; - - fp16 = val; - asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in)); - return int8[0]; -} - -template <> -__device__ inline int16_t cuda_cast(half2 val) -{ - union - { - int8_t int8[2]; - int16_t int16; - }; - - int8[0] = cuda_cast(val.x); - int8[1] = cuda_cast(val.y); - return int16; -} - -template <> -__device__ inline int8_t cuda_cast(float val) -{ - union - { - int8_t int8[2]; - int16_t int16; - }; - - asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); - return int8[0]; -} - -template <> -__device__ inline int16_t cuda_cast(float2 val) -{ - union - { - int8_t int8[2]; - int16_t int16; - }; - - int8[0] = cuda_cast(val.x); - int8[1] = cuda_cast(val.y); - return int16; -} - -template <> -__device__ inline half2 cuda_cast(int16_t val) -{ - union - { - int8_t int8[2]; - int16_t int16; - }; - - int16 = val; - return make_half2(int8[0], int8[1]); -} - -template <> -__device__ inline float2 cuda_cast(int16_t val) -{ - union - { - int8_t int8[2]; - int16_t int16; - }; - - int16 = val; - return make_float2(int8[0], int8[1]); -} - -#ifdef ENABLE_BF16 -template <> -__device__ inline __nv_bfloat16 cuda_cast(int32_t val) -{ - return static_cast(val); -} - -template <> -__device__ inline 
__nv_bfloat16 cuda_cast(int8_t val) -{ - return static_cast(val); -} - -template <> -__device__ inline int8_t cuda_cast(__nv_bfloat16 val) -{ - return static_cast(val); -} - -template <> -__device__ inline float cuda_cast(__nv_bfloat16 val) -{ - return __bfloat162float(val); -} - -template <> -__device__ inline float2 cuda_cast(__nv_bfloat162 val) -{ - return bf1622float2(val); -} - -template <> -__device__ inline half cuda_cast(__nv_bfloat16 val) -{ - return __float2half(__bfloat162float(val)); -} - -template <> -__device__ inline int16_t cuda_cast(__nv_bfloat162 val) -{ - return bf1622int16(val); -} - -template <> -__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) -{ - return __float2bfloat16(val); -} - -template <> -__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val) -{ - return __float2bfloat16(__half2float(val)); -} - -template <> -__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val) -{ - return bf162bf162(val); -} - -template <> -__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val) -{ - return __float2bfloat162_rn(val); -} - -template <> -__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val) -{ - return float22bf162(val); -} - -template <> -__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val) -{ - union - { - int8_t int8[2]; - int16_t int16; - }; - - int16 = val; - __nv_bfloat162 res; - res.x = cuda_cast<__nv_bfloat16>(int8[0]); - res.y = cuda_cast<__nv_bfloat16>(int8[1]); - return res; -} - -template <> -__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val) -{ - return float22bf162(__half22float2(val)); -} - -#endif // ENABLE BF16 - -template -__device__ inline T cuda_abs(T val) -{ - assert(false); - return {}; -} - -template <> -__device__ inline float cuda_abs(float val) -{ - return fabs(val); -} - -template <> -__device__ inline float2 cuda_abs(float2 val) -{ - return make_float2(fabs(val.x), fabs(val.y)); -} - -template <> -__device__ inline half cuda_abs(half val) -{ - return __habs(val); -} - -template <> -__device__ inline half2 cuda_abs(half2 val) -{ - return __habs2(val); -} - -#ifdef ENABLE_BF16 - -#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) -template <> -__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val) -{ - return __habs(val); -} - -template <> -__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val) -{ - return __habs2(val); -} -#endif - -#endif // ENABLE_FP16 - -template -__device__ inline To cuda_sum(Ti val) -{ - return cuda_cast(val); -}; - -template -__device__ inline To cuda_sum(float2 val) -{ - return cuda_cast(val.x + val.y); -}; - -// Unary maximum: compute the max of a vector type -template -__device__ inline To cuda_max(Ti val) -{ - return cuda_cast(val); -}; - -template <> -__device__ inline float cuda_max(float2 val) -{ - return fmaxf(val.x, val.y); -} - -template <> -__device__ inline half cuda_max(half2 val) -{ - return __hmax(val.x, val.y); -} - -#ifdef ENABLE_BF16 -template <> -__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val) -{ -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - return __hmax(val.x, val.y); -#else - assert(0); - asm volatile("brkpt;\n" ::); - return __nv_bfloat16(0); -#endif -} -#endif - -// Binary maximum: compute the max of two values. -template -__device__ inline T cuda_max(T val1, T val2) -{ - return (val1 > val2) ? 
val1 : val2; -} - -template <> -__device__ inline float2 cuda_max(float2 val1, float2 val2) -{ - float2 out; - out.x = fmaxf(val1.x, val2.x); - out.y = fmaxf(val1.y, val2.y); - return out; -} - -template <> -__device__ inline half2 cuda_max(half2 val1, half2 val2) -{ - return __hmax2(val1, val2); -} - -#ifdef ENABLE_BF16 -template <> -__device__ inline __nv_bfloat162 cuda_max(__nv_bfloat162 val1, __nv_bfloat162 val2) -{ - return __hmax2(val1, val2); -} -#endif // ENABLE_BF16 - -// Binary maximum: compute the min of two values. -template -__device__ inline T cuda_min(T val1, T val2) -{ - return (val1 < val2) ? val1 : val2; -} - -template <> -__device__ inline float2 cuda_min(float2 val1, float2 val2) -{ - float2 out; - out.x = fminf(val1.x, val2.x); - out.y = fminf(val1.y, val2.y); - return out; -} - -template <> -__device__ inline half2 cuda_min(half2 val1, half2 val2) -{ - return __hmin2(val1, val2); -} - -#ifdef ENABLE_BF16 -template <> -__device__ inline __nv_bfloat162 cuda_min(__nv_bfloat162 val1, __nv_bfloat162 val2) -{ - return __hmin2(val1, val2); -} -#endif // ENABLE_BF16 - -// Helper function of clamping the val into the given range. -template -inline __device__ T cuda_clamp(T val, T minVal, T maxVal) -{ - return cuda_min(cuda_max(val, minVal), maxVal); -} - -#ifdef ENABLE_FP8 -template <> -__device__ inline float2 cuda_cast(__nv_fp8x2_e4m3 val) -{ - return bf1622float2(fp8x2_e4m3_to_bfloat2(&val)); -} - -template <> -__device__ inline half2 cuda_cast(__nv_fp8x2_e4m3 val) -{ - return fp8x2_e4m3_to_half2(&val); -} - -template <> -__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val) -{ - return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val))); -} - -template <> -__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, half2>(half2 val) -{ - return __nv_fp8x2_e4m3(cuda_cast(val)); -} - -template <> -__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, __nv_bfloat162>(__nv_bfloat162 val) -{ - return __nv_fp8x2_e4m3(cuda_cast(val)); -} - -template <> -__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val) -{ - return __nv_fp8_e4m3(val); -} - -template <> -__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val) -{ - return __nv_fp8_e4m3(val); -} - -template <> -__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val) -{ - return __nv_fp8_e4m3(val); -} - -template <> -__device__ inline float cuda_cast(__nv_fp8_e4m3 val) -{ - return (float) val; -} - -template <> -__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val) -{ - return fp8x2_e4m3_to_bfloat2(&val); -} - -template <> -__device__ inline int8_t cuda_cast(__nv_fp8_e4m3 val) -{ - // no impl - return 0; -} - -template <> -__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val) -{ - return cuda_cast<__nv_fp8_e4m3>(cuda_cast<__nv_bfloat16>(cuda_cast(val))); -} - -#endif // ENABLE_FP8 - -} // namespace common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h deleted file mode 100644 index 13ee3367e97..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaUtils.h +++ /dev/null @@ -1,641 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "tensorrt_llm/common/cudaBf16Wrapper.h" -#include "tensorrt_llm/common/cudaDriverWrapper.h" -#include "tensorrt_llm/common/cudaFp8Utils.h" -#include "tensorrt_llm/common/logger.h" -#include "tensorrt_llm/common/tllmException.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#ifndef _WIN32 // Linux -#include -#endif // not WIN32 -#include -#ifdef _WIN32 // Windows -#include -#undef ERROR // A Windows header file defines ERROR as 0, but it's used in our logger.h enum. Logging breaks without - // this undef. -#endif // WIN32 - -namespace tensorrt_llm::common -{ - -// workspace for cublas gemm : 32MB -#define CUBLAS_WORKSPACE_SIZE 33554432 - -typedef struct __align__(4) -{ - half x, y, z, w; -} - -half4; - -/* **************************** type definition ***************************** */ - -enum CublasDataType -{ - FLOAT_DATATYPE = 0, - HALF_DATATYPE = 1, - BFLOAT16_DATATYPE = 2, - INT8_DATATYPE = 3, - FP8_DATATYPE = 4 -}; - -enum TRTLLMCudaDataType -{ - FP32 = 0, - FP16 = 1, - BF16 = 2, - INT8 = 3, - FP8 = 4 -}; - -enum class OperationType -{ - FP32, - FP16, - BF16, - INT8, - FP8 -}; - -/* **************************** debug tools ********************************* */ -static char const* _cudaGetErrorEnum(cudaError_t error) -{ - return cudaGetErrorString(error); -} - -static char const* _cudaGetErrorEnum(cublasStatus_t error) -{ - switch (error) - { - case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; - - case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; - - case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; - - case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; - - case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; - - case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; - - case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; - - case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; - - case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; - - case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; - } - return ""; -} - -template -void check(T result, char const* const func, char const* const file, int const line) -{ - if (result) - { - throw TllmException( - file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA runtime error in %s: %s", func, _cudaGetErrorEnum(result))); - } -} - -template -void checkEx(T result, std::initializer_list const& validReturns, char const* const func, char const* const file, - int const line) -{ - if (std::all_of(std::begin(validReturns), std::end(validReturns), [&result](T const& t) { return t != result; })) - { - throw TllmException( - file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA runtime error in %s: %s", func, _cudaGetErrorEnum(result))); - } -} - -#define check_cuda_error(val) 
check((val), #val, __FILE__, __LINE__) -#define check_cuda_error_2(val, file, line) check((val), #val, file, line) - -inline std::optional isCudaLaunchBlocking() -{ - static bool firstCall = true; - static std::optional result = std::nullopt; - - if (firstCall) - { - char const* env = std::getenv("CUDA_LAUNCH_BLOCKING"); - if (env != nullptr && std::string(env) == "1") - { - result = true; - } - else if (env != nullptr && std::string(env) == "0") - { - result = false; - } - firstCall = false; - } - - return result; -} - -inline bool doCheckError() -{ - auto const cudaLaunchBlocking = isCudaLaunchBlocking(); -#ifndef NDEBUG - bool const checkError = cudaLaunchBlocking.value_or(true); -#else - bool const checkError = cudaLaunchBlocking.value_or(false); -#endif - - return checkError; -} - -inline void syncAndCheck(char const* const file, int const line) -{ - if (doCheckError()) - { - cudaDeviceSynchronize(); - check(cudaGetLastError(), "cudaGetLastError", file, line); - } -} - -#define sync_check_cuda_error() tensorrt_llm::common::syncAndCheck(__FILE__, __LINE__) - -#define PRINT_FUNC_NAME_() \ - do \ - { \ - std::cout << "[TensorRT-LLM][CALL] " << __FUNCTION__ << " " << std::endl; \ - } while (0) - -// clang-format off -template struct packed_type; -template <> struct packed_type { using type = float; }; // we don't need to pack float by default -template <> struct packed_type { using type = half2; }; - -#ifdef ENABLE_BF16 -template<> -struct packed_type<__nv_bfloat16> { - using type = __nv_bfloat162; -}; -#endif - -#ifdef ENABLE_FP8 -template<> -struct packed_type<__nv_fp8_e4m3> { - using type = __nv_fp8x2_e4m3; -}; -#endif - -template struct num_elems; -template <> struct num_elems { static constexpr int value = 1; }; -template <> struct num_elems { static constexpr int value = 2; }; -template <> struct num_elems { static constexpr int value = 4; }; -template <> struct num_elems { static constexpr int value = 1; }; -template <> struct num_elems { static constexpr int value = 2; }; -#ifdef ENABLE_BF16 -template <> struct num_elems<__nv_bfloat16> { static constexpr int value = 1; }; -template <> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; }; -#endif -#ifdef ENABLE_FP8 -template <> struct num_elems<__nv_fp8_e4m3> { static constexpr int value = 1; }; -template <> struct num_elems<__nv_fp8x2_e4m3> { static constexpr int value = 2; }; -#endif - -template struct packed_as; -template struct packed_as { using type = T; }; -template<> struct packed_as { using type = half2; }; -template<> struct packed_as { using type = float2; }; -template<> struct packed_as { using type = int16_t; }; -template<> struct packed_as { using type = int2; }; -template<> struct packed_as { using type = half; }; -template<> struct packed_as { using type = float; }; -#ifdef ENABLE_BF16 -template<> struct packed_as<__nv_bfloat16, 2> { using type = __nv_bfloat162; }; -template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16; }; -#endif -#ifdef ENABLE_FP8 -template<> struct packed_as<__nv_fp8_e4m3, 2> { using type = __nv_fp8x2_e4m3; }; -template<> struct packed_as<__nv_fp8x2_e4m3, 1> { using type = __nv_fp8_e4m3; }; -template<> struct packed_as<__nv_fp8_e5m2, 2> { using type = __nv_fp8x2_e5m2; }; -template<> struct packed_as<__nv_fp8x2_e5m2, 1> { using type = __nv_fp8_e5m2; }; -#endif - -inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); } -inline __device__ float2 operator+(float2 a, float2 b) { return make_float2(a.x + b.x, a.y + b.y); 
} -inline __device__ float2 operator-(float2 a, float2 b) { return make_float2(a.x - b.x, a.y - b.y); } - -inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); } -inline __device__ float2 operator+(float2 a, float b) { return make_float2(a.x + b, a.y + b); } -inline __device__ float2 operator-(float2 a, float b) { return make_float2(a.x - b, a.y - b); } - -// clang-format on - -template -struct CudaDataType -{ -}; - -template <> -struct CudaDataType -{ - static constexpr cudaDataType_t value = cudaDataType::CUDA_R_32F; -}; - -template <> -struct CudaDataType -{ - static constexpr cudaDataType_t value = cudaDataType::CUDA_R_16F; -}; - -#ifdef ENABLE_BF16 -template <> -struct CudaDataType<__nv_bfloat16> -{ - static constexpr cudaDataType_t value = cudaDataType::CUDA_R_16BF; -}; -#endif - -inline int getSMVersion() -{ - int device{-1}; - check_cuda_error(cudaGetDevice(&device)); - int sm_major = 0; - int sm_minor = 0; - check_cuda_error(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device)); - check_cuda_error(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); - return sm_major * 10 + sm_minor; -} - -inline int getDevice() -{ - int current_dev_id = 0; - check_cuda_error(cudaGetDevice(¤t_dev_id)); - return current_dev_id; -} - -inline int getDeviceCount() -{ - int count = 0; - check_cuda_error(cudaGetDeviceCount(&count)); - return count; -} - -/// @brief Identifies the memory type of the given pointer. -template -cudaMemoryType getPtrCudaMemoryType(T* ptr) -{ - cudaPointerAttributes attributes{}; - check_cuda_error(cudaPointerGetAttributes(&attributes, ptr)); - return attributes.type; -} - -/// Get the memory info -/// \return The free and total amount of memory in bytes -inline std::tuple getDeviceMemoryInfo(bool const useUvm) -{ - if (useUvm) - { - size_t freeSysMem = 0; - size_t totalSysMem = 0; -#ifndef _WIN32 // Linux - struct sysinfo info - { - }; - - sysinfo(&info); - totalSysMem = info.totalram * info.mem_unit; - freeSysMem = info.freeram * info.mem_unit; -#else // Windows - MEMORYSTATUSEX memInfo; - memInfo.dwLength = sizeof(memInfo); - GlobalMemoryStatusEx(&memInfo); - totalSysMem = memInfo.ullTotalPhys; - freeSysMem = memInfo.ullAvailPhys; -#endif // WIN32 - - TLLM_LOG_INFO("Using UVM based system memory for KV cache, total memory %0.2f GB, available memory %0.2f GB", - ((double) totalSysMem / 1e9), ((double) freeSysMem / 1e9)); - return {freeSysMem, totalSysMem}; - } - - size_t free = 0; - size_t total = 0; - check_cuda_error(cudaMemGetInfo(&free, &total)); - TLLM_LOG_DEBUG("Using GPU memory for KV cache, total memory %0.2f GB, available memory %0.2f GB", - ((double) total / 1e9), ((double) free / 1e9)); - return {free, total}; -} - -/// @brief Gets the memory allocation granularity for the current device. -/// -/// @return size_t The size of the smallest difference in memory size supported by the current device. 
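/// @note Implemented below with cuMemGetAllocationGranularity(), checked via TLLM_CU_CHECK,
/// requesting CU_MEM_ALLOC_GRANULARITY_MINIMUM for a pinned device allocation on the
/// current device; callers round cuMemCreate() request sizes up to a multiple of this value.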
-inline size_t getAllocationGranularity() -{ - auto const currentDevice = getDevice(); - ::CUmemAllocationProp prop = {}; - - prop.type = ::CU_MEM_ALLOCATION_TYPE_PINNED; - prop.location.type = ::CU_MEM_LOCATION_TYPE_DEVICE; - prop.location.id = currentDevice; - prop.requestedHandleTypes = ::CU_MEM_HANDLE_TYPE_NONE; - - // Get the minimum granularity supported for allocation with cuMemCreate() - size_t granularity = 0; - TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM)); - return granularity; -} - -inline int getMultiProcessorCount() -{ - int device_id = 0; - int multi_processor_count = 0; - check_cuda_error(cudaGetDevice(&device_id)); - check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count, cudaDevAttrMultiProcessorCount, device_id)); - return multi_processor_count; -} - -inline int getMaxSharedMemoryPerBlockOptin() -{ - int device_id = 0; - int max_shared_memory_per_block = 0; - check_cuda_error(cudaGetDevice(&device_id)); - check_cuda_error( - cudaDeviceGetAttribute(&max_shared_memory_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id)); - return max_shared_memory_per_block; -} - -template -inline size_t divUp(const T1& a, const T2& n) -{ - auto const tmp_a = static_cast(a); - auto const tmp_n = static_cast(n); - return (tmp_a + tmp_n - 1) / tmp_n; -} - -inline int roundUp(int a, int n) -{ - return divUp(a, n) * n; -} - -template ::value>, - typename = std::enable_if_t::value>> -auto constexpr ceilDiv(T numerator, U denominator) -{ - return (numerator + denominator - 1) / denominator; -} - -template -void printAbsMean(T const* buf, uint64_t size, cudaStream_t stream, std::string name = "") -{ - if (buf == nullptr) - { - TLLM_LOG_WARNING("%s is an nullptr, skip!", name.c_str()); - return; - } - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); - T* h_tmp = new T[size]; - cudaMemcpyAsync(h_tmp, buf, sizeof(T) * size, cudaMemcpyDeviceToHost, stream); - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); - double sum = 0.0f; - uint64_t zero_count = 0; - float max_val = -1e10; - bool find_inf = false; - for (uint64_t i = 0; i < size; i++) - { - if (std::isinf((float) (h_tmp[i]))) - { - find_inf = true; - continue; - } - sum += abs((double) h_tmp[i]); - if ((float) h_tmp[i] == 0.0f) - { - zero_count++; - } - max_val = max_val > abs(float(h_tmp[i])) ? max_val : abs(float(h_tmp[i])); - } - TLLM_LOG_INFO("%20s size: %u, abs mean: %f, abs sum: %f, abs max: %f, find inf: %s", name.c_str(), size, sum / size, - sum, max_val, find_inf ? "true" : "false"); - delete[] h_tmp; - cudaDeviceSynchronize(); - check_cuda_error(cudaGetLastError()); -} - -template -void printToStream(T const* result, int const size, FILE* strm) -{ - bool const split_rows = (strm == stdout); - if (result == nullptr) - { - TLLM_LOG_WARNING("It is an nullptr, skip! 
\n"); - return; - } - T* tmp = reinterpret_cast(malloc(sizeof(T) * size)); - check_cuda_error(cudaMemcpy(tmp, result, sizeof(T) * size, cudaMemcpyDeviceToHost)); - for (int i = 0; i < size; ++i) - { - fprintf(strm, "%f, ", static_cast(tmp[i])); - if (split_rows && ((i + 1) % 10) == 0) - fprintf(strm, "\n"); - } - if (!split_rows || (size % 10) != 0) - { - fprintf(strm, "\n"); - } - free(tmp); -} - -template -void printToScreen(T const* result, int const size) -{ - printToStream(result, size, stdout); -} - -template -void print2dToStream(T const* result, int const r, int const c, int const stride, FILE* strm) -{ - if (result == nullptr) - { - TLLM_LOG_WARNING("It is an nullptr, skip! \n"); - return; - } - for (int ri = 0; ri < r; ++ri) - { - T const* ptr = result + ri * stride; - printToStream(ptr, c, strm); - } - fprintf(strm, "\n"); -} - -template -void print2dToScreen(T const* result, int const r, int const c, int const stride) -{ - print2dToStream(result, r, c, stride, stdout); -} - -template -void print2dToFile(std::string fname, T const* result, int const r, int const c, int const stride) -{ - FILE* fp = fopen(fname.c_str(), "wt"); - if (fp != nullptr) - { - print2dToStream(result, r, c, stride, fp); - fclose(fp); - } -} - -inline void print_float_(float x) -{ - printf("%7.3f ", x); -} - -inline void print_element_(float x) -{ - print_float_(x); -} - -inline void print_element_(half x) -{ - print_float_((float) x); -} - -#ifdef ENABLE_BF16 -inline void print_element_(__nv_bfloat16 x) -{ - print_float_((float) x); -} -#endif - -#ifdef ENABLE_FP8 -inline void print_element_(__nv_fp8_e4m3 x) -{ - print_float_((float) x); -} -#endif - -inline void print_element_(uint32_t ul) -{ - printf("%7" PRIu32, ul); -} - -inline void print_element_(uint64_t ull) -{ - printf("%7" PRIu64, ull); -} - -inline void print_element_(int32_t il) -{ - printf("%7" PRId32, il); -} - -inline void print_element_(int64_t ill) -{ - printf("%7" PRId64, ill); -} - -template -inline void printMatrix(T const* ptr, int m, int k, int stride, bool is_device_ptr) -{ - T* tmp; - if (is_device_ptr) - { - // k < stride ; stride = col-dimension. 
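// Device pointer: stage the full m x stride buffer on the host before printing.
// Only the first k columns of each row are printed below, so k may be smaller than
// stride when dumping a view into a wider row-major allocation.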
- tmp = reinterpret_cast(malloc(m * stride * sizeof(T))); - check_cuda_error(cudaMemcpy(tmp, ptr, sizeof(T) * m * stride, cudaMemcpyDeviceToHost)); - cudaDeviceSynchronize(); - } - else - { - tmp = const_cast(ptr); - } - - for (int ii = -1; ii < m; ++ii) - { - if (ii >= 0) - { - printf("%07d ", ii); - } - else - { - printf(" "); - } - - for (int jj = 0; jj < k; jj += 1) - { - if (ii >= 0) - { - print_element_(tmp[ii * stride + jj]); - } - else - { - printf("%7d ", jj); - } - } - printf("\n"); - } - if (is_device_ptr) - { - free(tmp); - } -} - -template void printMatrix(float const* ptr, int m, int k, int stride, bool is_device_ptr); -template void printMatrix(half const* ptr, int m, int k, int stride, bool is_device_ptr); -#ifdef ENABLE_BF16 -template void printMatrix(__nv_bfloat16 const* ptr, int m, int k, int stride, bool is_device_ptr); -#endif -#ifdef ENABLE_FP8 -template void printMatrix(__nv_fp8_e4m3 const* ptr, int m, int k, int stride, bool is_device_ptr); -#endif -template void printMatrix(uint32_t const* ptr, int m, int k, int stride, bool is_device_ptr); -template void printMatrix(uint64_t const* ptr, int m, int k, int stride, bool is_device_ptr); -template void printMatrix(int const* ptr, int m, int k, int stride, bool is_device_ptr); - -} // namespace tensorrt_llm::common - -/* - * Macros compliant with TensorRT coding conventions - */ -#define TLLM_CUDA_CHECK(stat) \ - do \ - { \ - tensorrt_llm::common::check((stat), #stat, __FILE__, __LINE__); \ - } while (0) - -// We use singleton memory pool and the order of destructors depends on the compiler implementation. We find that the -// cudaFree/cudaFreeHost is called after cudaruntime destruction on Windows. There will be an cudaErrorCudartUnloading -// error. However, it is safe to ignore this error because the cuda runtime is already exited, we are no more worried -// about the memory leaks. -#define TLLM_CUDA_CHECK_FREE_RESOURCE(stat) \ - do \ - { \ - tensorrt_llm::common::checkEx((stat), {cudaSuccess, cudaErrorCudartUnloading}, #stat, __FILE__, __LINE__); \ - } while (0) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp deleted file mode 100644 index 334ad236906..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tensorrt_llm/common/logger.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/common/tllmException.h" -#include - -namespace tensorrt_llm::common -{ - -Logger::Logger() -{ - char* isFirstRankOnlyChar = std::getenv("TLLM_LOG_FIRST_RANK_ONLY"); - bool isFirstRankOnly = (isFirstRankOnlyChar != nullptr && std::string(isFirstRankOnlyChar) == "ON"); - - auto const* levelName = std::getenv("TLLM_LOG_LEVEL"); - if (levelName != nullptr) - { - auto level = [levelName = std::string(levelName)]() - { - if (levelName == "TRACE") - return TRACE; - if (levelName == "DEBUG") - return DEBUG; - if (levelName == "INFO") - return INFO; - if (levelName == "WARNING") - return WARNING; - if (levelName == "ERROR") - return ERROR; - TLLM_THROW("Invalid log level: %s", levelName.c_str()); - }(); - // If TLLM_LOG_FIRST_RANK_ONLY=ON, set LOG LEVEL of other device to ERROR - if (isFirstRankOnly) - { - auto const deviceId = getDevice(); - if (deviceId != 1) - { - level = ERROR; - } - } - setLevel(level); - } -} - -void Logger::log(std::exception const& ex, Logger::Level level) -{ - log(level, "%s: %s", TllmException::demangle(typeid(ex).name()).c_str(), ex.what()); -} - -Logger* Logger::getLogger() -{ - thread_local Logger instance; - return &instance; -} -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/logger.h b/sgl-kernel/3rdparty/tensorrt_llm/common/logger.h deleted file mode 100644 index df84e226389..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/logger.h +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/stringUtils.h" - -namespace tensorrt_llm::common -{ - -class Logger -{ - -// On Windows, the file wingdi.h is included which has -// #define ERROR 0 -// This breaks everywhere ERROR is used in the Level enum -#ifdef _WIN32 -#undef ERROR -#endif // _WIN32 - -public: - enum Level - { - TRACE = 0, - DEBUG = 10, - INFO = 20, - WARNING = 30, - ERROR = 40 - }; - - static Logger* getLogger(); - - Logger(Logger const&) = delete; - void operator=(Logger const&) = delete; - -#if defined(_MSC_VER) - template - void log(Level level, char const* format, Args const&... args); - - template - void log(Level level, int rank, char const* format, Args const&... args); -#else - template - void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0))); - - template - void log(Level level, int rank, char const* format, Args const&... args) __attribute__((format(printf, 4, 0))); -#endif - - template - void log(Level level, std::string const& format, Args const&... args) - { - return log(level, format.c_str(), args...); - } - - template - void log(Level const level, int const rank, std::string const& format, Args const&... 
args) - { - return log(level, rank, format.c_str(), args...); - } - - void log(std::exception const& ex, Level level = Level::ERROR); - - Level getLevel() const - { - return level_; - } - - void setLevel(Level const level) - { - level_ = level; - log(INFO, "Set logger level to %s", getLevelName(level)); - } - - bool isEnabled(Level const level) const - { - return level_ <= level; - } - -private: - static auto constexpr kPREFIX = "[TensorRT-LLM]"; - -#ifndef NDEBUG - Level const DEFAULT_LOG_LEVEL = DEBUG; -#else - Level const DEFAULT_LOG_LEVEL = INFO; -#endif - Level level_ = DEFAULT_LOG_LEVEL; - - Logger(); // NOLINT(modernize-use-equals-delete) - - static inline char const* getLevelName(Level const level) - { - switch (level) - { - case TRACE: return "TRACE"; - case DEBUG: return "DEBUG"; - case INFO: return "INFO"; - case WARNING: return "WARNING"; - case ERROR: return "ERROR"; - } - - TLLM_THROW("Unknown log level: %d", level); - } - - static inline std::string getPrefix(Level const level) - { - return fmtstr("%s[%s] ", kPREFIX, getLevelName(level)); - } - - static inline std::string getPrefix(Level const level, int const rank) - { - return fmtstr("%s[%s][%d] ", kPREFIX, getLevelName(level), rank); - } -}; - -template -void Logger::log(Logger::Level level, char const* format, Args const&... args) -{ - if (isEnabled(level)) - { - auto const fmt = getPrefix(level) + format; - auto& out = level_ < WARNING ? std::cout : std::cerr; - if constexpr (sizeof...(args) > 0) - { - out << fmtstr(fmt.c_str(), args...); - } - else - { - out << fmt; - } - out << std::endl; - } -} - -template -void Logger::log(Logger::Level const level, int const rank, char const* format, Args const&... args) -{ - if (isEnabled(level)) - { - auto const fmt = getPrefix(level, rank) + format; - auto& out = level_ < WARNING ? std::cout : std::cerr; - if constexpr (sizeof...(args) > 0) - { - out << fmtstr(fmt.c_str(), args...); - } - else - { - out << fmt; - } - out << std::endl; - } -} - -#define TLLM_LOG(level, ...) \ - do \ - { \ - auto* const logger = tensorrt_llm::common::Logger::getLogger(); \ - if (logger->isEnabled(level)) \ - { \ - logger->log(level, __VA_ARGS__); \ - } \ - } while (0) - -#define TLLM_LOG_TRACE(...) TLLM_LOG(tensorrt_llm::common::Logger::TRACE, __VA_ARGS__) -#define TLLM_LOG_DEBUG(...) TLLM_LOG(tensorrt_llm::common::Logger::DEBUG, __VA_ARGS__) -#define TLLM_LOG_INFO(...) TLLM_LOG(tensorrt_llm::common::Logger::INFO, __VA_ARGS__) -#define TLLM_LOG_WARNING(...) TLLM_LOG(tensorrt_llm::common::Logger::WARNING, __VA_ARGS__) -#define TLLM_LOG_ERROR(...) TLLM_LOG(tensorrt_llm::common::Logger::ERROR, __VA_ARGS__) -#define TLLM_LOG_EXCEPTION(ex, ...) tensorrt_llm::common::Logger::getLogger()->log(ex, ##__VA_ARGS__) -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh deleted file mode 100644 index a228d3f9fc6..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
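The removed Logger reads TLLM_LOG_LEVEL from the environment once, and the TLLM_LOG macro checks isEnabled(level) before any formatting happens. A reduced, self-contained sketch of the same level-gating idea; the macro and function names here are hypothetical, and the enum values simply mirror the deleted header.

#include <cstdio>
#include <cstdlib>
#include <cstring>

enum class LogLevel { Trace = 0, Debug = 10, Info = 20, Warning = 30, Error = 40 };

// Resolve the threshold once from TLLM_LOG_LEVEL; default to INFO.
inline LogLevel logThreshold()
{
    static LogLevel const level = [] {
        char const* name = std::getenv("TLLM_LOG_LEVEL");
        if (name == nullptr)                   return LogLevel::Info;
        if (std::strcmp(name, "TRACE") == 0)   return LogLevel::Trace;
        if (std::strcmp(name, "DEBUG") == 0)   return LogLevel::Debug;
        if (std::strcmp(name, "WARNING") == 0) return LogLevel::Warning;
        if (std::strcmp(name, "ERROR") == 0)   return LogLevel::Error;
        return LogLevel::Info;  // unrecognized values fall back to INFO in this sketch
    }();
    return level;
}

// Format only when the level is enabled, as the deleted TLLM_LOG macro does.
#define SKETCH_LOG(level, ...)                                            \
    do                                                                    \
    {                                                                     \
        if (static_cast<int>(level) >= static_cast<int>(logThreshold()))  \
        {                                                                 \
            std::printf(__VA_ARGS__);                                     \
            std::printf("\n");                                            \
        }                                                                 \
    } while (0)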
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" -#include "tensorrt_llm/common/cudaFp8Utils.h" -#include -#include -#include - -namespace tensorrt_llm -{ -namespace common -{ - -template -struct QuantTypeStaticVals; - -template <> -struct QuantTypeStaticVals -{ - static constexpr float MAX_VAL = 127.f; - static constexpr float MIN_SCALING_FACTOR = 0.f; - static constexpr float MIN_SCALING_FACTOR_RCP = FLT_MAX; -}; - -#ifdef ENABLE_FP8 - -template <> -struct QuantTypeStaticVals<__nv_fp8_e4m3> -{ - static constexpr float MAX_VAL = 448.f; - // Ref: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L720 - static constexpr float MIN_SCALING_FACTOR = 1.0f / (448.f * 512.f); - static constexpr float MIN_SCALING_FACTOR_RCP = (448.f * 512.f); -}; - -#endif // ENABLE_FP8 - -} // namespace common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h b/sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h deleted file mode 100644 index 052d9c8c819..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/quantization.h +++ /dev/null @@ -1,358 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
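The QuantTypeStaticVals constants deleted above (127 for int8, 448 for FP8 E4M3, plus a non-zero minimum scaling factor for FP8) exist so that a per-tensor scale derived from an absolute maximum stays finite and invertible. A small sketch of that arithmetic, using only the constants shown; the function name is made up and this is not code from the removed kernels.

#include <algorithm>
#include <cstdio>

constexpr float kFp8E4m3MaxVal = 448.f;                   // MAX_VAL for __nv_fp8_e4m3 above
constexpr float kFp8MinScale   = 1.0f / (448.f * 512.f);  // MIN_SCALING_FACTOR above

// Map the observed absolute maximum onto the FP8 range, clamped away from zero.
inline float computeFp8Scale(float amax)
{
    return std::max(amax / kFp8E4m3MaxVal, kFp8MinScale);
}

int main()
{
    std::printf("scale for amax = 3.2: %g\n", computeFp8Scale(3.2f));
    std::printf("scale for amax = 0.0: %g (clamped)\n", computeFp8Scale(0.0f));
    return 0;
}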
- */ - -#pragma once - -#include -#include -#include - -namespace tensorrt_llm -{ -namespace common -{ - -class QuantMode -{ - // [WARNING] KEEP BELOW DEFINITION IN SYNC WITH tensorrt_llm/quantization/mode.py -public: - using BaseType = std::uint32_t; - - explicit constexpr QuantMode(BaseType value) noexcept - : mValue{value} - { - } - - QuantMode() noexcept = default; - - constexpr QuantMode(QuantMode const&) noexcept = default; - - constexpr QuantMode& operator=(QuantMode const& other) noexcept = default; - - static constexpr QuantMode none() noexcept - { - return QuantMode(BaseType(0)); - } - - static constexpr QuantMode int4Weights() noexcept - { - return QuantMode(BaseType(1u) << 0); - } - - static constexpr QuantMode int8Weights() noexcept - { - return QuantMode(BaseType(1u) << 1); - } - - static constexpr QuantMode activations() noexcept - { - return QuantMode(BaseType(1u) << 2); - } - - static constexpr QuantMode perChannelScaling() noexcept - { - return QuantMode(BaseType(1u) << 3); - } - - static constexpr QuantMode perTokenScaling() noexcept - { - return QuantMode(BaseType(1u) << 4); - } - - static constexpr QuantMode perGroupScaling() noexcept - { - return QuantMode(BaseType(1u) << 5); - } - - static constexpr QuantMode int8KvCache() noexcept - { - return QuantMode(BaseType(1u) << 6); - } - - static constexpr QuantMode fp8KvCache() noexcept - { - return QuantMode(BaseType(1u) << 7); - } - - static constexpr QuantMode fp8Qdq() noexcept - { - return QuantMode(BaseType(1u) << 8); - } - - static constexpr QuantMode fp8RowWise() noexcept - { - return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9); - } - - static constexpr QuantMode w4a8QServe() noexcept - { - return QuantMode(BaseType(1u) << 10); - } - - constexpr BaseType value() const noexcept - { - return mValue; - } - - constexpr bool isSet(QuantMode const& mode) const noexcept - { - return (mValue & mode.value()) == mode.value(); - } - - constexpr bool hasInt4Weights() const noexcept - { - return isSet(int4Weights()); - } - - constexpr bool hasInt8Weights() const noexcept - { - return isSet(int8Weights()); - } - - constexpr bool hasActivations() const noexcept - { - return isSet(activations()); - } - - constexpr bool hasPerChannelScaling() const noexcept - { - return isSet(perChannelScaling()); - } - - constexpr bool hasPerTokenScaling() const noexcept - { - return isSet(perTokenScaling()); - } - - constexpr bool hasPerGroupScaling() const noexcept - { - return isSet(perGroupScaling()); - } - - constexpr bool hasStaticActivationScaling() const noexcept - { - return !hasPerTokenScaling(); - } - - constexpr bool hasInt8KvCache() const noexcept - { - return isSet(int8KvCache()); - } - - constexpr bool hasFp8KvCache() const noexcept - { - return isSet(fp8KvCache()); - } - - constexpr bool hasFp8Qdq() const noexcept - { - return isSet(fp8Qdq()); - } - - constexpr bool hasFp8RowWise() const noexcept - { - return isSet(fp8RowWise()); - } - - constexpr bool hasKvCacheQuant() const noexcept - { - return hasInt8KvCache() || hasFp8KvCache(); - } - - static constexpr QuantMode fromDescription(bool quantizeWeights = false, bool quantizeActivations = false, - bool perToken = false, bool perChannel = false, bool perGroup = false, bool useInt4Weights = false, - bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false, - bool useW4a8QServe = false) - { - QuantMode quantMode{}; - if (quantizeWeights) - { - if (useInt4Weights) - quantMode += int4Weights(); - else - quantMode 
+= int8Weights(); - } - - if (quantizeActivations) - { - quantMode += activations(); - } - - if (perChannel) - { - quantMode += QuantMode::perChannelScaling(); - } - if (perToken) - { - quantMode += QuantMode::perTokenScaling(); - } - if (perGroup) - { - quantMode += QuantMode::perGroupScaling(); - } - - if (useInt8KvCache) - { - quantMode += int8KvCache(); - } - - if (useFp8KvCache) - { - quantMode += fp8KvCache(); - } - - if (useFp8Qdq) - { - quantMode += fp8Qdq(); - } - - if (useFp8RowWise) - { - quantMode += fp8RowWise(); - } - - if (useW4a8QServe) - { - quantMode += w4a8QServe(); - } - - return quantMode; - } - - static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false) - { - return fromDescription(true, true, perToken, perChannel); - } - - static constexpr QuantMode useQServe(bool perGroup) - { - return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true); - } - - static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false) - { - return fromDescription(true, false, false, false, perGroup, useInt4Weights); - } - - static QuantMode const fromQuantAlgo( - std::optional quantAlgo = std::nullopt, std::optional kvCacheQuantAlgo = std::nullopt) - { - QuantMode quantMode{}; - if (quantAlgo == "W8A16") - { - quantMode = useWeightOnly(false, false); - } - else if (quantAlgo == "W4A16") - { - quantMode = useWeightOnly(true, false); - } - else if (quantAlgo == "W4A16_AWQ") - { - quantMode = useWeightOnly(true, true); - } - else if (quantAlgo == "W4A8_AWQ") - { - quantMode = useWeightOnly(true, true); - } - else if (quantAlgo == "W4A8_QSERVE_PER_GROUP") - { - quantMode = useQServe(false); - } - else if (quantAlgo == "W4A8_QSERVE_PER_CHANNEL") - { - quantMode = useQServe(true); - } - else if (quantAlgo == "W4A16_GPTQ") - { - quantMode = useWeightOnly(true, true); - } - else if (quantAlgo == "W8A8_SQ_PER_CHANNEL") - { - quantMode = useSmoothQuant(false, true); - } - else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PLUGIN") - { - quantMode = useSmoothQuant(false, false); - } - else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN") - { - quantMode = useSmoothQuant(true, true); - } - else if (quantAlgo == "W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN") - { - quantMode = useSmoothQuant(false, true); - } - else if (quantAlgo == "W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN") - { - quantMode = useSmoothQuant(true, false); - } - else if (quantAlgo == "FP8") - { - quantMode = fromDescription(false, false, false, false, false, false, false, false, true); - } - else if (quantAlgo == "FP8_ROWWISE") - { - quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true); - } - - if (kvCacheQuantAlgo == "INT8") - { - quantMode += int8KvCache(); - } - else if (kvCacheQuantAlgo == "FP8") - { - quantMode += fp8KvCache(); - } - - return quantMode; - } - - constexpr QuantMode operator+(QuantMode const& other) const noexcept - { - return QuantMode(mValue | other.mValue); - } - - constexpr QuantMode& operator+=(QuantMode const& other) noexcept - { - return *this = *this + other; - } - - constexpr QuantMode operator-(QuantMode const& other) const noexcept - { - return QuantMode(mValue & ~other.mValue); - } - - constexpr QuantMode& operator-=(QuantMode const& other) noexcept - { - return *this = *this - other; - } - - constexpr bool operator==(QuantMode const& other) const noexcept - { - return mValue == other.mValue; - } - - constexpr bool operator!=(QuantMode const& other) const noexcept - { - return 
!(*this == other); - } - -private: - BaseType mValue{0}; -}; - -} // namespace common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh deleted file mode 100644 index c5a4fe0e24e..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh +++ /dev/null @@ -1,399 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) -#include -#else -#include -#endif -#include "tensorrt_llm/common/cudaTypeUtils.cuh" -#include -#include -#include -#include -#include - -namespace cg = cooperative_groups; - -namespace tensorrt_llm -{ -namespace common -{ - -template -struct BytesToType; - -template <> -struct BytesToType<1> -{ - using type = uint8_t; -}; - -template <> -struct BytesToType<2> -{ - using type = uint16_t; -}; - -template <> -struct BytesToType<4> -{ - using type = uint32_t; -}; - -template <> -struct BytesToType<8> -{ - using type = uint64_t; -}; - -template <> -struct BytesToType<16> -{ - using type = float4; -}; - -template -__device__ inline void copy(void const* local, void* data) -{ - using T = typename BytesToType::type; - - T const* in = static_cast(local); - T* out = static_cast(data); - *out = *in; -} - -static float constexpr HALF_FLT_MAX = 65504.F; -#define FINAL_MASK 0xffffffff - -template -__inline__ __device__ T warpReduceSum(T val) -{ -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80 - return val; -} - -/* Calculate the sum of all elements in a block */ -template -__inline__ __device__ T blockReduceSum(T val) -{ - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - val = warpReduceSum(val); - - if (lane == 0) - shared[wid] = val; - - __syncthreads(); - - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T) (0.0f); - val = warpReduceSum(val); - - return val; -} - -template -__inline__ __device__ T warpReduceMax(T val) -{ -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); - return val; -} - -/* Calculate the maximum of all elements in a block */ -template -__inline__ __device__ T blockReduceMax(T val) -{ - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx - - val = warpReduceMax(val); // get maxx in each warp - - if (lane == 0) // record in-warp maxx by warp Idx - shared[wid] = val; - - __syncthreads(); - - // Modify from blockDim.x << 5 to blockDim.x / 32. 
to prevent - // blockDim.x is not divided by 32 - val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; - val = warpReduceMax(val); - - return val; -} - -/* Calculate the maximum of all elements in a block */ -template -__inline__ __device__ T blockAllReduceMax(T val) -{ - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx - - val = warpReduceMax(val); // get maxx in each warp - - if (lane == 0) // record in-warp maxx by warp Idx - shared[wid] = val; - - __syncthreads(); - - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; - val = warpReduceMax(val); - - return val; -} - -template -__inline__ __device__ T warpReduceSumV2(T* val) -{ -#pragma unroll - for (int i = 0; i < NUM; i++) - { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); - } - return (T) (0.0f); -} - -template -__inline__ __device__ T blockReduceSumV2(T* val) -{ - static __shared__ T shared[NUM][33]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; - - warpReduceSumV2(val); - - if (lane == 0) - { -#pragma unroll - for (int i = 0; i < NUM; i++) - { - shared[i][wid] = val[i]; - } - } - - __syncthreads(); - - bool is_mask = threadIdx.x < (blockDim.x / 32.f); -#pragma unroll - for (int i = 0; i < NUM; i++) - { - val[i] = is_mask ? shared[i][lane] : (T) (0.0f); - } - warpReduceSumV2(val); - return (T) 0.0f; -} - -template -__inline__ __device__ T warpReduceMaxV2(T* val) -{ -#pragma unroll - for (int i = 0; i < NUM; i++) - { -#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); - } - return (T) (0.0f); -} - -template -__inline__ __device__ T blockReduceMaxV2(T* val) -{ - static __shared__ T shared[32][NUM]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx - - warpReduceMaxV2(val); // get maxx in each warp - - if (lane == 0) // record in-warp maxx by warp Idx - { -#pragma unroll - for (int i = 0; i < NUM; i++) - { - shared[wid][i] = val[i]; - } - } - - __syncthreads(); - - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - bool is_mask = threadIdx.x < (blockDim.x / 32.f); -#pragma unroll - for (int i = 0; i < NUM; i++) - { - val[i] = is_mask ? 
shared[lane][i] : (T) -1e20f; - } - warpReduceMaxV2(val); - - return (T) 0.0f; -} - -template -__inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* cgBlockReduceSumElements_shm) -{ - cg::thread_block cta = cg::this_thread_block(); - cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta); - - int const tid = cta.thread_rank(); - int const blockz = blockDim.x; - for (int i = 0; i < NUM; i++) - { -#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) - cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus()); -#else - // TODO Add implementation here - if (threadIdx.x == 0 && blockIdx.x == 0) - { - printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n"); - assert(false); - } -#endif - } - cg::sync(cta); - if (tid == 0) - { -#pragma unroll - for (int i = 0; i < NUM; i++) - { - float beta = 0.0f; - for (int j = 0; j < blockz; j += 32) - { - beta += cgBlockReduceSumElements_shm[i * blockz + j]; - } - element_list[i] = beta; - } - } -} - -template -struct TopK -{ - int p[MAX_K]; // index, being -1 at the tail if the array is not full - T u[MAX_K]; // value in descend order, being -MAX_T_VAL if the element is invalid - - __device__ __forceinline__ void insert(T const elem, int const elem_id) - { - if (elem_id < 0) - { - return; - } - // Condition of updating the array - // 1. array is not full - // 2. elem is greater than the smallest (last) element in the array - // 3. elem is equal to the smallest (last) element in the array but its elem_id is smaller - bool const need_update - = (p[MAX_K - 1] == -1 || elem > u[MAX_K - 1] || elem == u[MAX_K - 1] && elem_id < p[MAX_K - 1]); - if (!need_update) - { - return; - } - // Find suitable index for the new element - int i; - for (i = MAX_K - 2; i >= 0; --i) - { - bool const need_decrease = (p[i] == -1 || elem > u[i] || elem == u[i] && elem_id < p[i]); - if (!need_decrease) - break; - } - // Move elements to correct positions - for (int k = MAX_K - 2; k >= i; --k) - { - p[k + 1] = p[k]; - u[k + 1] = u[k]; - } - p[i] = elem_id; - u[i] = elem; - } - - __device__ __forceinline__ void init() - { - T const MAX_T_VAL = (std::is_same::value) ? HALF_FLT_MAX : FLT_MAX; - for (int i = 0; i < MAX_K; i++) - { - p[i] = -1; - u[i] = -MAX_T_VAL; - } - } -}; - -template -__device__ __forceinline__ TopK reduce_topk_op(TopK const& a, TopK const& b) -{ - TopK res = a; - for (int i = 0; i < MAX_K; ++i) - res.insert(b.u[i], b.p[i]); - return res; -} - -template -struct TopK_2 -{ - int p = -1; - T u = -((std::is_same::value) ? HALF_FLT_MAX : FLT_MAX); - - __device__ __forceinline__ void insert(T elem, int elem_id) - { - if (elem > u) - { - u = elem; - p = elem_id; - } - } - - __device__ __forceinline__ void init() - { - u = -((std::is_same::value) ? HALF_FLT_MAX : FLT_MAX); - p = -1; - } -}; - -template -__device__ __forceinline__ TopK_2 reduce_topk_op_2(TopK_2 const& a, TopK_2 const& b) -{ - return a.u > b.u ? a : b; -} - -template -__device__ __forceinline__ T clamp_inf_for_half(float const input) -{ - return input; -} - -template <> -__device__ __forceinline__ half clamp_inf_for_half(float const input) -{ - // clamp inf values to enable fp16 training - return input > 0.0f ? 
(half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000); -} - -} // namespace common -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp deleted file mode 100644 index f1c6f88b431..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/common/stringUtils.h" -#include "tensorrt_llm/common/assert.h" - -#include -#include -#include -#include -#include - -namespace tensorrt_llm::common -{ - -namespace -{ -std::string vformat(char const* fmt, va_list args) -{ - va_list args0; - va_copy(args0, args); - auto const size = vsnprintf(nullptr, 0, fmt, args0); - if (size <= 0) - return ""; - - std::string stringBuf(size, char{}); - auto const size2 = std::vsnprintf(&stringBuf[0], size + 1, fmt, args); - - TLLM_CHECK_WITH_INFO(size2 == size, std::string(std::strerror(errno))); - - return stringBuf; -} - -} // namespace - -std::string fmtstr(char const* format, ...) -{ - va_list args; - va_start(args, format); - std::string result = vformat(format, args); - va_end(args); - return result; -}; - -std::unordered_set str2set(std::string const& input, char delimiter) -{ - std::unordered_set values; - if (!input.empty()) - { - std::stringstream valStream(input); - std::string val; - while (std::getline(valStream, val, delimiter)) - { - if (!val.empty()) - { - values.insert(val); - } - } - } - return values; -}; - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h deleted file mode 100644 index 9c5ecde98c5..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.h +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
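The reduction helpers removed above all follow the same two-stage shape: a butterfly __shfl_xor_sync pass within each warp, then per-warp partials staged through shared memory and reduced again by the first warp. A condensed CUDA sketch of that pattern for a float sum; this is a simplified restatement, not the deleted header.

#include <cuda_runtime.h>

// Butterfly reduction across the 32 lanes of a warp.
__device__ inline float warpSum(float val)
{
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1)
        val += __shfl_xor_sync(0xffffffff, val, mask, 32);
    return val;
}

// Block-wide sum: one partial per warp in shared memory, then a final warp reduction.
// Only threads in warp 0 end up holding the full block sum.
__device__ inline float blockSum(float val)
{
    __shared__ float partials[32];
    int const lane = threadIdx.x & 0x1f;
    int const warp = threadIdx.x >> 5;

    val = warpSum(val);
    if (lane == 0)
        partials[warp] = val;
    __syncthreads();

    int const numWarps = (blockDim.x + 31) / 32;
    val = (threadIdx.x < numWarps) ? partials[lane] : 0.f;
    return warpSum(val);
}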
- */ - -#pragma once - -#if ENABLE_BF16 -#include -#endif // ENABLE_BF16 -#include - -#include // std::make_unique -#include // std::stringstream -#include -#include -#include - -namespace tensorrt_llm::common -{ -#if ENABLE_BF16 -static inline std::basic_ostream& operator<<(std::basic_ostream& stream, __nv_bfloat16 const& val) -{ - stream << __bfloat162float(val); - return stream; -} -#endif // ENABLE_BF16 - -static inline std::basic_ostream& operator<<(std::basic_ostream& stream, __half const& val) -{ - stream << __half2float(val); - return stream; -} - -inline std::string fmtstr(std::string const& s) -{ - return s; -} - -inline std::string fmtstr(std::string&& s) -{ - return s; -} - -#if defined(_MSC_VER) -std::string fmtstr(char const* format, ...); -#else -std::string fmtstr(char const* format, ...) __attribute__((format(printf, 1, 2))); -#endif - -// __PRETTY_FUNCTION__ is used for neat debugging printing but is not supported on Windows -// The alternative is __FUNCSIG__, which is similar but not identical -#if defined(_WIN32) -#define __PRETTY_FUNCTION__ __FUNCSIG__ -#endif - -auto constexpr kDefaultDelimiter = ", "; - -template -inline TStream& arr2outCasted(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter) -{ - out << "("; - if (size > 0) - { - for (size_t i = 0; i < size - 1; ++i) - { - out << static_cast(arr[i]) << delim; - } - out << static_cast(arr[size - 1]); - } - out << ")"; - return out; -} - -template -inline TStream& arr2out(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter) -{ - return arr2outCasted(out, arr, size, delim); -} - -template -inline std::string arr2str(T* arr, size_t size, char const* delim = kDefaultDelimiter) -{ - std::stringstream ss; - return arr2out(ss, arr, size, delim).str(); -} - -template -inline std::string vec2str(std::vector const& vec, char const* delim = kDefaultDelimiter) -{ - return arr2str(vec.data(), vec.size(), delim); -} - -inline bool strStartsWith(std::string const& str, std::string const& prefix) -{ - return str.rfind(prefix, 0) == 0; -} - -/// @brief Split a string into a set of strings using a delimiter -std::unordered_set str2set(std::string const& input, char delimiter); - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp deleted file mode 100644 index b410613d055..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
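The deleted string helpers are thin conveniences: fmtstr formats printf-style into a std::string by sizing with vsnprintf first and writing second, and str2set splits a delimited string into a set. A standalone restatement of those two idioms, with names changed to make clear it is a sketch.

#include <cstdarg>
#include <cstdio>
#include <sstream>
#include <string>
#include <unordered_set>

// printf-style formatting into std::string: measure first, then write.
std::string formatString(char const* fmt, ...)
{
    va_list args;
    va_start(args, fmt);
    va_list sizing;
    va_copy(sizing, args);
    int const size = std::vsnprintf(nullptr, 0, fmt, sizing);
    va_end(sizing);

    std::string out;
    if (size > 0)
    {
        out.resize(size);
        std::vsnprintf(&out[0], size + 1, fmt, args);
    }
    va_end(args);
    return out;
}

// Split "a,b,c" into {"a", "b", "c"}, skipping empty pieces.
std::unordered_set<std::string> splitToSet(std::string const& input, char delimiter)
{
    std::unordered_set<std::string> values;
    std::stringstream stream(input);
    std::string item;
    while (std::getline(stream, item, delimiter))
        if (!item.empty())
            values.insert(item);
    return values;
}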
- */ - -#include "tensorrt_llm/common/tllmException.h" -#include "tensorrt_llm/common/stringUtils.h" - -#include -#if !defined(_MSC_VER) -#include -#include -#include -#endif -#include - -namespace tensorrt_llm::common -{ - -namespace -{ -int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2; -} - -#if !defined(_MSC_VER) - -TllmException::TllmException(char const* file, std::size_t line, std::string const& msg) - : std::runtime_error{""} -{ - mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES); - auto const trace = getTrace(); - std::runtime_error::operator=( - std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())}); -} -#else -TllmException::TllmException(char const* file, std::size_t line, std::string const& msg) - : mNbFrames{} - , std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)} -{ -} -#endif - -TllmException::~TllmException() noexcept = default; - -std::string TllmException::getTrace() const -{ -#if defined(_MSC_VER) - return ""; -#else - auto const trace = backtrace_symbols(mCallstack.data(), mNbFrames); - std::ostringstream buf; - for (auto i = 1; i < mNbFrames; ++i) - { - Dl_info info; - if (dladdr(mCallstack[i], &info) && info.dli_sname) - { - auto const clearName = demangle(info.dli_sname); - buf << fmtstr("%-3d %*p %s + %zd", i, VOID_PTR_SZ, mCallstack[i], clearName.c_str(), - static_cast(mCallstack[i]) - static_cast(info.dli_saddr)); - } - else - { - buf << fmtstr("%-3d %*p %s", i, VOID_PTR_SZ, mCallstack[i], trace[i]); - } - if (i < mNbFrames - 1) - buf << std::endl; - } - - if (mNbFrames == MAX_FRAMES) - buf << std::endl << "[truncated]"; - - std::free(trace); - return buf.str(); -#endif -} - -std::string TllmException::demangle(char const* name) -{ -#if defined(_MSC_VER) - return name; -#else - std::string clearName{name}; - auto status = -1; - auto const demangled = abi::__cxa_demangle(name, nullptr, nullptr, &status); - if (status == 0) - { - clearName = demangled; - std::free(demangled); - } - return clearName; -#endif -} - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h b/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h deleted file mode 100644 index 47e0e63d3fc..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -#define NEW_TLLM_EXCEPTION(...) 
\ - tensorrt_llm::common::TllmException(__FILE__, __LINE__, tensorrt_llm::common::fmtstr(__VA_ARGS__)) - -namespace tensorrt_llm::common -{ - -class TllmException : public std::runtime_error -{ -public: - static auto constexpr MAX_FRAMES = 128; - - explicit TllmException(char const* file, std::size_t line, std::string const& msg); - - ~TllmException() noexcept override; - - [[nodiscard]] std::string getTrace() const; - - static std::string demangle(char const* name); - -private: - std::array mCallstack{}; - int mNbFrames; -}; - -} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h b/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h deleted file mode 100644 index 1406e821333..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include - -namespace tensorrt_llm::common -{ - -std::uintptr_t constexpr kCudaMemAlign = 128; - -inline int8_t* alignPtr(int8_t* ptr, uintptr_t to) -{ - uintptr_t addr = (uintptr_t) ptr; - if (addr % to) - { - addr += to - addr % to; - } - return (int8_t*) addr; -} - -constexpr size_t alignSize(size_t size, size_t to) -{ - if ((size % to) != 0U) - { - size += to - size % to; - } - return size; -} - -inline int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment) -{ - uintptr_t addr = (uintptr_t) ptr; - addr += previousWorkspaceSize; - return alignPtr((int8_t*) addr, alignment); -} - -inline int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize) -{ - return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign); -} - -inline int8_t* nextWorkspacePtr( - int8_t* const base, uintptr_t& offset, uintptr_t const size, uintptr_t const alignment = kCudaMemAlign) -{ - uintptr_t curr_offset = offset; - uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment; - int8_t* newptr = size == 0 ? 
nullptr : base + curr_offset; - offset = next_offset; - return newptr; -} - -inline int8_t* nextWorkspacePtrWithAlignment( - int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment = kCudaMemAlign) -{ - return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment); -} - -inline size_t calculateTotalWorkspaceSize( - size_t const* workspaces, int count, uintptr_t const alignment = kCudaMemAlign) -{ - size_t total = 0; - for (int i = 0; i < count; i++) - { - total += workspaces[i]; - if (workspaces[i] % alignment) - { - total += alignment - (workspaces[i] % alignment); - } - } - return total; -} - -}; // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp deleted file mode 100644 index 61a41031bfb..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp +++ /dev/null @@ -1,352 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
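The workspace helpers deleted above are plain alignment arithmetic: round each sub-buffer size up to kCudaMemAlign (128 bytes) and advance a running offset. The following host-side sketch shows the same bookkeeping with a hypothetical alignUp helper.

#include <cstddef>
#include <cstdio>

constexpr std::size_t kAlign = 128;  // kCudaMemAlign in the deleted header

// Round size up to the next multiple of the alignment.
constexpr std::size_t alignUp(std::size_t size, std::size_t to)
{
    return (size + to - 1) / to * to;
}

int main()
{
    // Lay out three sub-buffers inside one workspace, each starting on a 128-byte boundary.
    std::size_t const sizes[] = {1000, 4096, 37};
    std::size_t offset = 0;
    for (std::size_t s : sizes)
    {
        std::printf("sub-buffer at offset %zu (size %zu)\n", offset, s);
        offset += alignUp(s, kAlign);
    }
    std::printf("total workspace: %zu bytes\n", offset);
    return 0;
}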
- * - **************************************************************************************************/ -#pragma once - -#include - -#include -#include -#include - -// Config - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10)) -#define CUTE_ARCH_RED_F16_SM70_ENABLED -#endif - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) -#define CUTE_ARCH_RED_VEC_SM90_ENABLED -#define CUTE_ARCH_RED_BF16_SM90_ENABLED -#endif - -namespace cute -{ - -////////////////////////////////// -// Wrapper around CUDA's atomicAdd -////////////////////////////////// - -template -struct TypedAtomicAdd -{ - using SRegisters = T[1]; - using DRegisters = T[1]; - - CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst) - { - atomicAdd(&dst, src); - } -}; - -template -struct Copy_Traits> -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout::value>>>; - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout::value>>>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -////////////////////////////////// -// F16 ADD PTX -////////////////////////////////// - -struct SM70_RED_ADD_NOFTZ_F16 -{ - using SRegisters = uint16_t[1]; - using DRegisters = uint16_t[1]; - - CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED) - asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -struct SM70_RED_ADD_NOFTZ_F16x2 -{ - using SRegisters = uint32_t[1]; - using DRegisters = uint32_t[1]; - - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED) - asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -struct SM90_RED_ADD_NOFTZ_F16x2_V2 -{ - using SRegisters = uint32_t[2]; - using DRegisters = uint64_t[1]; - - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED) - asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = 
Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -struct SM90_RED_ADD_NOFTZ_F16x2_V4 -{ - using SRegisters = uint32_t[4]; - using DRegisters = uint128_t[1]; - - CUTE_HOST_DEVICE static void copy( - uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED) - asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1), - "r"(src2), "r"(src3)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -////////////////////////////////// -// BF16 ADD PTX -////////////////////////////////// - -struct SM90_RED_ADD_NOFTZ_BF16 -{ - using SRegisters = uint16_t[1]; - using DRegisters = uint16_t[1]; - - CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -////////////////////////////////// - -struct SM90_RED_ADD_NOFTZ_BF16x2 -{ - using SRegisters = uint32_t[1]; - using DRegisters = uint32_t[1]; - - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -////////////////////////////////// - -struct SM90_RED_ADD_NOFTZ_BF16x2_V2 -{ - using SRegisters = uint32_t[2]; - using DRegisters = uint64_t[1]; - - CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - 
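The Copy_Traits specializations above wrap red.global.add PTX so that an epilogue "copy" accumulates into global memory instead of overwriting it; the vectorized f16x2/bf16x2 variants do the same for several packed values per instruction. As a rough functional equivalent only (not the PTX path and not the deleted traits), a plain atomicAdd-based scatter-add looks like this:

#include <cuda_runtime.h>

// Accumulate src[i] into dst[index[i]] instead of storing it; this is the effect the
// reduction copy atoms achieve, minus the packed half/bf16 vector instructions.
__global__ void scatterAdd(float const* __restrict__ src, int const* __restrict__ index,
                           float* __restrict__ dst, int n)
{
    int const i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        atomicAdd(&dst[index[i]], src[i]);
}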
-////////////////////////////////// - -struct SM90_RED_ADD_NOFTZ_BF16x2_V4 -{ - using SRegisters = uint32_t[4]; - using DRegisters = uint128_t[1]; - - CUTE_HOST_DEVICE static void copy( - uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst) - { -#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) - asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1), - "r"(src2), "r"(src3)); -#else - CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); -#endif - } -}; - -template <> -struct Copy_Traits -{ - // Logical thread id to thread idx (one-thread) - using ThrID = Layout<_1>; - - // Map from (src-thr,src-val) to bit - using SrcLayout = Layout>; - - // Map from (dst-thr,dst-val) to bit - using DstLayout = Layout>; - - // Reference map from (thr,val) to bit - using RefLayout = SrcLayout; -}; - -////////////////////////////////// - -} // end namespace cute diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h deleted file mode 100644 index 2362da4f7f2..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h +++ /dev/null @@ -1,120 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! 
\file - \brief Templates exposing architecture support for multiply-add operations -*/ - -#pragma once -#include "cutlass_extensions/weight_only_quant_op.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace arch -{ - -// Tag which triggers MMA which will trigger -struct OpMultiplyAddDequantizeInterleavedBToA; - -/* - Below we have extra tags to signal what kind of dequantization we want to do - (per col, scale only fine grained, finegrained with zero). This still lets us - the existing template infrastructure (incl. that in CUTLASS). However, we - split out the template below into OpMultiplyAddDequantizeInterleavedBToA along - with the quantization op before instantiating the GEMM pieces. - - Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of - code we need to duplicate. - */ -struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale; -struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale; -struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; - -// The default just forwards the original operator -template -struct TagOperator -{ - using TaggedOperator = MmaOp; -}; - -// Specializations below attach more information to the operator -template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale; -}; - -template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale; -}; - -template <> -struct TagOperator -{ - using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias; -}; - -// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original -// operator + the extra information. If no extra info was tagged, the dequant op per column scaling -// as a default. -template -struct DetagOperator -{ - using Operator = TaggedMmaOp; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; -}; - -template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY; -}; - -template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; -}; - -template <> -struct DetagOperator -{ - using Operator = OpMultiplyAddDequantizeInterleavedBToA; - static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; -}; - -} // namespace arch -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h deleted file mode 100644 index c83a9a074da..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
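The TagOperator/DetagOperator pair deleted above is a compile-time trick: extra information (which dequantization flavor to use) is folded into the operator type so it can travel through the existing template plumbing, then unpacked where it is needed. A toy version of the same idea, with invented type names:

#include <type_traits>

struct OpPlain {};                 // stand-in for the base MMA operator
struct OpPlain_WithFineScale {};   // stand-in for a tagged variant

enum class QuantOp { PerColumnScale, FineGrainedScale };

// Attach the extra information by mapping to a tagged type (mini TagOperator).
template <typename Op, bool FineGrained>
struct Tag { using Tagged = Op; };

template <>
struct Tag<OpPlain, true> { using Tagged = OpPlain_WithFineScale; };

// Recover the base operator plus the encoded information (mini DetagOperator).
template <typename Tagged>
struct Detag
{
    using Op = Tagged;
    static constexpr QuantOp quantOp = QuantOp::PerColumnScale;
};

template <>
struct Detag<OpPlain_WithFineScale>
{
    using Op = OpPlain;
    static constexpr QuantOp quantOp = QuantOp::FineGrainedScale;
};

static_assert(std::is_same_v<Detag<Tag<OpPlain, true>::Tagged>::Op, OpPlain>);
static_assert(Detag<Tag<OpPlain, true>::Tagged>::quantOp == QuantOp::FineGrainedScale);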
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include "cutlass/device_kernel.h" -#include "tensorrt_llm/common/cudaUtils.h" - -namespace tensorrt_llm -{ -namespace cutlass_extensions -{ - -template -inline int compute_occupancy_for_kernel() -{ - - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - if (smem_size > (48 << 10)) - { - cudaFuncAttributes attr; - int device = 0; - int max_smem_per_block = 0; - tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); - tensorrt_llm::common::check_cuda_error( - cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); - if constexpr (enable_cutlass_3x) - { - tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::device_kernel)); - } - else - { - tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel)); - } - if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) - { - // This should mean that - // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size) - // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this - // configuration. - return 0; - } - - if constexpr (enable_cutlass_3x) - { - tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute( - cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - else - { - tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute( - cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - } - - int max_active_blocks = -1; - if constexpr (enable_cutlass_3x) - { - tensorrt_llm::common::check_cuda_error( - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel, - 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size)); - } - else - { - tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, cutlass::Kernel, GemmKernel::kThreadCount, smem_size)); - } - - return max_active_blocks; -} - -} // namespace cutlass_extensions -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp deleted file mode 100644 index bba25ec23a9..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp +++ /dev/null @@ -1,550 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. 
Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing elementwise operations used by epilogues. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/fast_math.h" - -#include "cute/numeric/numeric_types.hpp" -#include "cute/tensor.hpp" -#include "cutlass/trace.h" - -#include "cutlass_extensions/arch/copy_red_global.hpp" -#include "cutlass_extensions/util/gather_tensor.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace epilogue -{ -namespace collective -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -class EpilogueMoeFusedFinalize -{ -public: - using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized; - using DispatchPolicy = PtrArrayNoSmemWarpSpecialized; - - using ThreadEpilogueOp = ThreadEpilogueOp_; - using ElementOutput = typename ThreadEpilogueOp::ElementOutput; - using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; - using ElementCompute = typename ThreadEpilogueOp::ElementCompute; - using ElementIntermediate = typename ThreadEpilogueOp::ElementD; - - using ElementC = typename ThreadEpilogueOp::ElementC; - using StrideC = StrideC_; - using InternalStrideC = cute::remove_pointer_t; - using ElementD = ElementD_; - using StrideD = StrideD_; - using InternalStrideD = cute::remove_pointer_t; - - static_assert(!is_same_v, "Stride C must be a pointer"); - static_assert(is_same_v, "Stride D must not be a pointer"); - - using CopyAtomR2S = Copy_Atom; - using CopyAtomS2R = Copy_Atom; - using CopyAtomR2G = Copy_Atom; - static constexpr int AlignmentD = CopyAtomR2G::NumValSrc; - - using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{})); - - constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); - - struct SharedStorage - { - alignas(SmemAlignmentD) cute::ArrayEngine> smem_D; - }; - - struct TensorMapStorage - { - }; - - struct Arguments - { - typename 
ThreadEpilogueOp::Params thread{}; - ElementC const** ptr_C{}; - StrideC dC{}; - ElementD* ptr_D{}; - StrideD dD{}; - ElementBias const* ptr_bias; - StrideBias dBias{}; - ElementScale const* ptr_scale; - StrideScale dScale{}; - int64_t const* group_offset{}; - int32_t const* scatter_index{}; - cutlass::FastDivmod num_rows_in_final_output; - }; - - using Params = Arguments; - - // - // Methods - // - - template - static constexpr Params to_underlying_arguments( - ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace) - { - return args; - } - - template - static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0) - { - return 0; - } - - template - static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, - void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) - { - return cutlass::Status::kSuccess; - } - - template - CUTLASS_HOST_DEVICE static bool can_implement( - [[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args) - { - bool implementable = true; - if (problem_shape.is_host_problem_shape_available()) - { - // Check alignment for all problem sizes - for (int i = 0; i < problem_shape.groups(); i++) - { - auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); - auto [M, N, K, L] = problem_shape_MNKL; - implementable = implementable - && cutlass::detail::check_alignment(cute::make_shape(M, N, L), InternalStrideD{}); - } - } - - if (!implementable) - { - CUTLASS_TRACE_HOST( - " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global " - "reduction instruction.\n"); - } - return implementable; - } - - CUTLASS_HOST_DEVICE - EpilogueMoeFusedFinalize(Params const& params_) - : params(params_) - { - } - - CUTLASS_DEVICE - bool is_source_needed() - { - // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. 
- return params.ptr_C != nullptr - && (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0); - } - - template - CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK, - BlockCoordMNKL blk_coord_mnkl, cute::Tensor const& accumulators, TiledMma tiled_mma, - ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf) - { - using namespace cute; - using X = Underscore; - - static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); - static_assert(is_static::value, "ThreadBlock tile shape must be static"); - static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); - static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); - - auto synchronize = [&]() - { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - - // Separate out problem shape for convenience - auto M = get<0>(problem_shape_mnkl); - auto N = get<1>(problem_shape_mnkl); - auto L = get<3>(problem_shape_mnkl); - - auto mma_tile_m = tile_size<0>(tiled_mma); - auto mma_tile_n = tile_size<1>(tiled_mma); - auto epi_tile_m = size<0>(EpilogueTile{}); - auto epi_tile_n = size<1>(EpilogueTile{}); - - CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); - CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); - - // Batches are managed by using appropriate pointers to C and D matrices - int32_t const mock_L = 1; - int32_t const mock_l_coord = 0; - - // Slice to get the tile this CTA is responsible for - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; - - // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. - // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, - // we get the correct alpha/beta values for the current batch/group using group index. - ThreadEpilogueOp epilogue_op(params.thread, l_coord); - - SharedStorage& storage = *reinterpret_cast(smem_buf); - - Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{}); - Tensor sD = as_position_independent_swizzle_tensor(sD_); - - // Function to scatter output rows - auto& num_rows = params.num_rows_in_final_output; - auto read_scatter_map = IndexedGather(make_gmem_ptr(params.scatter_index + params.group_offset[l_coord])); - auto get_scatter_idx = [&](auto i) - { - auto scatter = read_scatter_map(i); - int quot, rem; - num_rows(quot, rem, scatter); - return rem; - }; - - // Represent the full output tensor - ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr; - auto dC = epilogue_op.is_source_needed() ? 
params.dC[l_coord] : InternalStrideC{}; - Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l) - Tensor mD_mnl = make_gather_tensor( - make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l) - - // Use fake shape for bias, it doesn't matter - bool const is_bias_needed = params.ptr_bias != nullptr; - Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias); - Tensor mScale_mnl = make_tensor( - make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale); - - Tensor gC_mnl - = local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gD_mnl - = local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - - Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) - Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) - - Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - Tensor gBias_mnl - = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - Tensor gScale_mnl - = local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) - - Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N) - Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N) - - Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - - // Get the smallest tiled copy we can use to retile the accumulators - TiledCopy tiled_copy_C_atom - = make_tiled_copy_C_atom(Copy_Atom{}, tiled_mma); - TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom); - - auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); - Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) - Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N) - Tensor tRS_rD = make_tensor(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N) - - // Make a tiled copy vectorized along major direction of D - auto tiled_s2r = [&]() - { - if constexpr (cutlass::gemm::detail::is_k_major()) - { - constexpr int NumThreadsMajor = epi_tile_n / AlignmentD; - constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; - return make_tiled_copy(CopyAtomS2R{}, - Layout, Int>, Stride, _1>>{}, - Layout>>{}); - } - else if constexpr (cutlass::gemm::detail::is_mn_major()) - { - constexpr int NumThreadsMajor = epi_tile_m / AlignmentD; - constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; - return make_tiled_copy(CopyAtomS2R{}, - Layout, Int>, Stride<_1, Int>>{}, - Layout, _1>>{}); - } - else - { - static_assert(cute::is_void_v, "Unsupported D gmem layout."); - } - }(); - - auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); - Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // 
((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - - // Allocate intermediate registers for a single subtile - Tensor tSR_rD = make_tensor(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rD_final = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rC = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rBias = make_tensor(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) - Tensor tSR_rScale = make_tensor(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) - - // Make an identity coordinate tensor for predicating our output MN tile - Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); - Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) - Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) - - // epilogue subtile loop - CUTLASS_PRAGMA_UNROLL - for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) - { - CUTLASS_PRAGMA_UNROLL - for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) - { - int mma_m = (epi_m * epi_tile_m) / mma_tile_m; - int mma_n = (epi_n * epi_tile_n) / mma_tile_n; - Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n); - - int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); - int r2s_v = epi_n_in_mma * size(tRS_rD); - CUTLASS_PRAGMA_UNROLL - for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v) - { - tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v); - } - - copy(tiled_r2s, tRS_rD, tRS_sD); - synchronize(); - - copy(tiled_s2r, tSR_sD, tSR_rD); - synchronize(); - - Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n); - Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n); - Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n); - Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n); - Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n); - - if (epilogue_op.is_source_needed()) - { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_rD); ++m) - { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_rD); ++n) - { - if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) - { - copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n)); - if (is_bias_needed) - { - copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); - } - copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tSR_rD); ++i) - { - auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n)); - if (is_bias_needed) - { - epi_value += static_cast(tSR_rBias(i, m, n)); - } - tSR_rD_final(i, m, n) = static_cast(tSR_rScale(i, m, n) * epi_value); - } - copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); - } - } - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tSR_rD); ++m) - { - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tSR_rD); ++n) - { - if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) - { - if (is_bias_needed) - { - copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); - } - copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size<0>(tSR_rD); ++i) - { - auto epi_value = epilogue_op(tSR_rD(i, m, n)); - if (is_bias_needed) - { - epi_value += static_cast(tSR_rBias(i, m, n)); - } - tSR_rD_final(i, m, n) = static_cast(tSR_rScale(i, m, n) * epi_value); - } - copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); - } - } - } - } - } - } - } - -private: - Params params; -}; - -namespace detail -{ - -template -constexpr auto get_vectorized_atomic_add_op() 
-{ - using namespace cute; - - auto constexpr MaxVecSize = size(MaxVec{}); - - if constexpr (is_same_v) - { - if constexpr (MaxVecSize >= 8) - { - return SM90_RED_ADD_NOFTZ_F16x2_V4{}; - } - else if constexpr (MaxVecSize >= 4) - { - return SM90_RED_ADD_NOFTZ_F16x2_V2{}; - } - else if constexpr (MaxVecSize >= 2) - { - return SM70_RED_ADD_NOFTZ_F16x2{}; - } - else - { - return SM70_RED_ADD_NOFTZ_F16{}; - } - } - else if constexpr (is_same_v) - { - if constexpr (MaxVecSize >= 8) - { - return SM90_RED_ADD_NOFTZ_BF16x2_V4{}; - } - else if constexpr (MaxVecSize >= 4) - { - return SM90_RED_ADD_NOFTZ_BF16x2_V2{}; - } - else if constexpr (MaxVecSize >= 2) - { - return SM90_RED_ADD_NOFTZ_BF16x2{}; - } - else - { - return SM90_RED_ADD_NOFTZ_BF16{}; - } - } - else - { - // non-vectorized atomic add for all other types until supported - return TypedAtomicAdd{}; - } -} - -} // namespace detail - -template -struct EpilogueMoeFusedFinalizeBuilder -{ - - // assuming cooperative kernel schedule - using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{})); - using EpilogueTile = Shape<_128, EpiTileN>; - - // Output of linear combination is ElementCompute instead of ElementD - // since we will be doing more computate on it, no need to cast yet. - using ThreadEpilogueOp - = cutlass::epilogue::thread::LinearCombination; - - using SmemLayoutAtomD - = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()); - using CopyAtomR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator()); - using CopyAtomS2R = DefaultCopy; - using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op()); - - template - struct Sm90TmaWarpSpecializedAdapterWithSmemStorage : detail::Sm90TmaWarpSpecializedAdapter - { - // We need to override this one using declaration because otherwise we double up on the smem - using TensorMapStorage = typename EpilogueOp::TensorMapStorage; - - using Base = detail::Sm90TmaWarpSpecializedAdapter; - - CUTLASS_HOST_DEVICE - Sm90TmaWarpSpecializedAdapterWithSmemStorage( - typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors) - : Base(params) - { - } - - // These functions depend on the type of TensorMapStorage - template - CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormap, - [[maybe_unused]] typename EpilogueOp::Params const& params, - [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t next_batch) - { - } - - template - CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormap, - [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] uint32_t lane_predicate) - { - } - }; - - using CollectiveOp = Sm90TmaWarpSpecializedAdapterWithSmemStorage< - EpilogueMoeFusedFinalize>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace collective -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h deleted file mode 100644 index f3c622b88a5..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h +++ /dev/null @@ -1,105 +0,0 @@ 
-/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Functor performing linear combination with a maximum operation used by epilogues. 
-*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/epilogue/thread/linear_combination_generic.h" -#include "cutlass/epilogue/thread/scale_type.h" -#include "cutlass/functional.h" -#include "cutlass/half.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace epilogue -{ -namespace thread -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -__forceinline__ __device__ float copysignf_pos(float a, float b) -{ - float r; - r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); - return r; -} - -__forceinline__ __device__ float tanh_opt(float x) -{ -#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) - float const exp_val = -1.f * fabs(2 * x); - return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); -#else - return fast_tanh(x); -#endif -} - -///////////////////////////////////////////////////////////////////////////////////////////////// -template <> -struct GELU_taylor -{ - static bool const kIsHeavy = true; - - CUTLASS_DEVICE - float operator()(float const& z) const - { - - float k0 = float(0.7978845608028654); - float k1 = float(0.044715); - - return float(cutlass::constants::half() * z - * (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); - } - - using Params = LinearCombinationGenericParams; - - CUTLASS_DEVICE - float operator()(float const& scalar, Params const& params_) const - { - return this->operator()(scalar); - } -}; - -} // namespace thread -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h deleted file mode 100644 index d3d4d0a45ab..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h +++ /dev/null @@ -1,352 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column. - - original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h - -*/ - -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass/arch/memory.h" -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/numeric_conversion.h" -#include "tensorrt_llm/common/quantization.h" - -namespace tk = tensorrt_llm::common; - -namespace cutlass -{ -namespace epilogue -{ -namespace threadblock -{ - -template -class EpilogueVisitorPerRowPerCol -{ -public: - using ThreadblockShape = ThreadblockShape_; - static int const kThreadCount = ThreadCount; - - using ScaleTileIterator = ScaleTileIterator_; - using OutputTileIterator = OutputTileIterator_; - using ElementwiseFunctor = ElementwiseFunctor_; - - static int const kIterations = OutputTileIterator::kIterations; - static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - - using ElementOutput = typename OutputTileIterator::Element; - using LayoutOutput = cutlass::layout::RowMajor; - using ElementAccumulator = ElementAccumulator_; - - using AlphaScaleElementType = typename ScaleTileIterator::Element; - - using ElementCompute = ElementCompute_; - using AccumulatorFragment = Array; - using ComputeFragment = Array; - using OutputVector = Array; - - static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; - static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); - - /// Argument structure - struct Arguments - { - - typename ElementwiseFunctor::Params elementwise; - int64_t batch_stride_alpha; - int64_t batch_stride_C; - int64_t batch_stride_D; - - // - // Methods - // - Arguments() - : batch_stride_alpha(0) - , batch_stride_C(0) - , batch_stride_D(0) - { - } - - Arguments(typename ElementwiseFunctor::Params elementwise_) - : elementwise(elementwise_) - , batch_stride_alpha(0) - , batch_stride_C(0) - , batch_stride_D(0) - { - } - - Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, - int64_t batch_stride_C_, int64_t batch_stride_D_) - : elementwise(elementwise_) - , batch_stride_alpha(batch_stride_alpha_) - , batch_stride_C(batch_stride_C_) - , batch_stride_D(batch_stride_D_) - { - } - }; - - struct Params - { - - typename ElementwiseFunctor::Params elementwise; - int64_t batch_stride_alpha; - int64_t batch_stride_C; - int64_t 
batch_stride_D; - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params() {} - - CUTLASS_HOST_DEVICE - Params(Arguments const& args) - : elementwise(args.elementwise) - , batch_stride_alpha(args.batch_stride_alpha) - , batch_stride_C(args.batch_stride_C) - , batch_stride_D(args.batch_stride_D) - { - } - }; - - /// Shared storage - struct SharedStorage - { - }; - -private: - Params const& params_; - SharedStorage& shared_storage_; - MatrixCoord extent_; - MatrixCoord extent_real_; - ElementwiseFunctor elementwise_; - - bool const per_token_quant_; - bool const per_channel_quant_; - - AlphaScaleElementType* ptr_alpha_row_; - AlphaScaleElementType* ptr_alpha_col_; - ScaleTileIterator iterator_alpha_col_; - OutputTileIterator iterator_C_; - OutputTileIterator iterator_D_; - - AlphaScaleElementType element_alpha_row_ = 1.0f; - AlphaScaleElementType element_alpha_col_ = 1.0f; - typename ScaleTileIterator::Fragment fragment_alpha_col_; - typename OutputTileIterator::Fragment fragment_C_; - typename OutputTileIterator::Fragment fragment_D_; - - ElementAccumulator beta_; - - int column_offset_; - - MatrixCoord thread_offset_; - -public: - CUTLASS_DEVICE - EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, - cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, - typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C, - typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row, - AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C, - typename OutputTileIterator::Element* ptr_D, - cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0, - cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) - : params_(params) - , shared_storage_(shared_storage) - , extent_(problem_size) - , elementwise_(params.elementwise) - , per_token_quant_(quant_option.hasPerTokenScaling()) - , per_channel_quant_(quant_option.hasPerChannelScaling()) - , ptr_alpha_row_(ptr_alpha_row) - , ptr_alpha_col_(ptr_alpha_col) - , iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset) - , iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset) - , iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset) - , extent_real_(problem_size_real) - { - beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr : params.elementwise.beta); - - if (beta_ == ElementAccumulator()) - { - iterator_C_.clear_mask(); - } - - if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) - { - element_alpha_col_ = *ptr_alpha_col_; - } - - if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) - { - element_alpha_row_ = *ptr_alpha_row_; - } - } - - /// Helper to indicate split-K behavior - CUTLASS_DEVICE - void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme - int split_k_slices) - { ///< Total number of split-K slices - } - - /// Called to set the batch index - CUTLASS_DEVICE - void set_batch_index(int batch_idx) - { - iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); - iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); - iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); - } - - /// Called at the start of the epilogue just before iterating over accumulator slices - CUTLASS_DEVICE - void begin_epilogue() - { - if (per_channel_quant_) - { - iterator_alpha_col_.load(fragment_alpha_col_); - } - } - - /// Called at the start of one step before starting accumulator exchange - CUTLASS_DEVICE - void begin_step(int step_idx) - { - fragment_D_.clear(); - fragment_C_.clear(); - - if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) - { - iterator_C_.load(fragment_C_); - ++iterator_C_; - } - } - - /// Called at the start of a row - CUTLASS_DEVICE - void begin_row(int row_idx) - { - // load alpha_row in begin_step only when per token(row) scaling is used - if (per_token_quant_) - { - int thread_offset_row - = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); - - arch::global_load( - element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); - } - } - - /// Called after accumulators have been exchanged for each accumulator vector - CUTLASS_DEVICE - void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) - { - - NumericArrayConverter source_converter; - - ComputeFragment result = source_converter(accum); - if (per_channel_quant_) - { - ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[column_idx]; - result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); - } - else - { - result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); - } - - // Convert to the output - NumericArrayConverter output_converter; - OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; - output = output_converter(result); - } - - /// Called at the end of a row - CUTLASS_DEVICE - void end_row(int row_idx) {} - - /// Called after all accumulator elements have been visited - CUTLASS_DEVICE - void end_step(int step_idx) - { - - iterator_D_.store(fragment_D_); - ++iterator_D_; - } - - /// Called after all steps have been completed - CUTLASS_DEVICE - void end_epilogue() {} - -private: - CUTLASS_DEVICE - ComputeFragment per_token_channel_scale_accumulator_( - ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) - { - - ComputeFragment result; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ComputeFragment::kElements; ++i) - { - result[i] = accum[i] * (scale_col[i] * scale_row); - } - - return result; - } - - CUTLASS_DEVICE - ComputeFragment per_token_scale_accumulator_( - ComputeFragment const& accum, AlphaScaleElementType const& 
scale_col, AlphaScaleElementType const& scale_row) - { - - ComputeFragment result; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < ComputeFragment::kElements; ++i) - { - result[i] = accum[i] * (scale_col * scale_row); - } - - return result; - } -}; - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h deleted file mode 100644 index 6f26d790170..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h +++ /dev/null @@ -1,282 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - - The epilogue rearranges the result of a matrix product through shared memory to match canonical - tensor layouts in global memory. Epilogues support conversion and reduction operations. 
- - original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h - -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/platform/platform.h" - -#include "cutlass/gemm/gemm.h" - -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/epilogue/thread/linear_combination_clamp.h" -#include "cutlass/epilogue/thread/linear_combination_gelu.h" -#include "cutlass/epilogue/thread/linear_combination_hardswish.h" -#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" -#include "cutlass/epilogue/thread/linear_combination_relu.h" -#include "cutlass/epilogue/thread/linear_combination_relu0.h" -#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" - -#include "cutlass/epilogue/thread/conversion_op.h" -#include "cutlass/epilogue/thread/reduction_op.h" - -#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" - -#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" -#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" -#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" -#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" -#include "cutlass/epilogue/threadblock/shared_load_iterator.h" -#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" -#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" -#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" -#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" -#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" - -#include "cutlass/epilogue/threadblock/epilogue.h" -#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" - -#include "cutlass/layout/permute.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace epilogue -{ -namespace threadblock -{ - -//////////////////////////////////////////////////////////////////////////////// - -namespace detail -{ - -/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. -template -struct DefaultIteratorsTensorOp -{ - using WarpTileIterator - = cutlass::epilogue::warp::TileIteratorTensorOpMixed; - - using SharedLoadIterator - = cutlass::epilogue::threadblock::SharedLoadIteratorMixed; - - static int const kFragmentsPerIteration = 2; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator used to load output tile from shared memory in epilogue. 
-/// -/// Satisfies: ReadableTileIterator -/// -template -class SharedLoadIteratorMixed -{ -public: - using ThreadMap = ThreadMap_; - using Shape = typename ThreadMap::Shape; - - using Element = int32_t; - - using Layout = layout::RowMajor; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - - static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; - - static int const kThreads = ThreadMap::kThreads; - - /// Fragment object - using Fragment = Array; - - /// Memory access size - using AccessType = AlignedArray; - - /// Vector type used for SMEM loads - using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), - const_min(16, kAlignment)>; - - static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; - -private: - // - // Data members - // - - /// Byte-level pointer - LoadType const* pointers_[kLoadsPerAccess]; - - /// Stride along adjacent rows in units of LoadType - int stride_; - -public: - // - // Methods - // - - /// Constructor - CUTLASS_DEVICE - SharedLoadIteratorMixed(TensorRef ref, int thread_idx) - : stride_((ref.stride(0) / LoadType::kElements)) - { - - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); - - // Initialize pointers - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] = reinterpret_cast(ref.data()); - - int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; - int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; - - col_idx += (bank_offset + i) % kLoadsPerAccess; - - pointers_[i] += thread_offset.row() * stride_ + col_idx; - } - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] += pointer_offset / LoadType::kElements; - } - } - - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const& offset) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kLoadsPerAccess; ++i) - { - pointers_[i] - += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const - { - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) - { - - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) - { - - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) - { - - int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ - + group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_ - + pointer_offset / LoadType::kElements; - - int frag_row_idx - = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - LoadType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) - { - - int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < kLoadsPerAccess; ++v) - { - - int vector_idx - = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * 
kLoadsPerAccess); - - LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; - - frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; - } - } - } - } - } - } - - /// Loads a fragment - CUTLASS_DEVICE - void load(Fragment& frag) const - { - - load_with_pointer_offset(frag, 0); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h deleted file mode 100644 index 233d633a823..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h +++ /dev/null @@ -1,141 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/** - * @file epilogue_helpers.h - * - * This file includes types for the epilogues. The empty structs exist so we can signal to template - * code the type of epilogue we want to run, and let the underlying code specify the details such as - * element types, accumulator type and elements per vector access. 
- * - */ - -#pragma once - -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/epilogue/thread/linear_combination_generic.h" -#include "cutlass/epilogue/thread/linear_combination_relu.h" -#include "cutlass/epilogue/thread/linear_combination_silu.h" -#include "cutlass_extensions/epilogue/thread/fused_activations.h" -#include - -namespace tensorrt_llm -{ -namespace cutlass_extensions -{ - -struct EpilogueOpBiasSilu -{ -}; - -struct EpilogueOpBiasReLU -{ -}; - -struct EpilogueOpBiasFtGelu -{ -}; - -struct EpilogueOpBias -{ -}; - -struct EpilogueOpDefaultSilu -{ -}; - -struct EpilogueOpDefaultReLU -{ -}; - -struct EpilogueOpDefaultFtGelu -{ -}; - -struct EpilogueOpDefault -{ -}; - -template -struct Epilogue -{ - static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag"); -}; - -constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationSilu; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationRelu; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationGeneric; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombination; -}; - -constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationSilu; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationRelu; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombinationGeneric; -}; - -template -struct Epilogue -{ - using Op = cutlass::epilogue::thread::LinearCombination; -}; - -} // namespace cutlass_extensions -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl deleted file mode 100644 index 593eca06e3d..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl +++ /dev/null @@ -1,221 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/arch/mma.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/gemm/collective/builders/sm90_common.inl" - -// SM90 Collective Builders should be used only starting CUDA 12.0 -#if (__CUDACC_VER_MAJOR__ >= 12) -#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail -{ - -// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. -template -constexpr int compute_stage_count_or_override_gated(StageCountAutoCarveout stage_count) -{ - // 32 bytes to account for barriers etc. 
- constexpr int stage_barrier_bytes = 32; - constexpr int a_bits = static_cast(sizeof_bits::value); - constexpr int b_bits = static_cast(sizeof_bits::value); - constexpr int stage_bytes = [&]() -> int - { - if constexpr (SwapAB) - { - return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 - + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + stage_barrier_bytes; - } - else - { - return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 - + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + stage_barrier_bytes; - } - }(); - - return (CapacityBytes - carveout_bytes) / stage_bytes; -} - -} // namespace detail - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA_TMA_WS_SS -template class Activation, bool SwapAB> -struct CollectiveBuilderGated - || cute::is_same_v - || cute::is_same_v - || cute::is_same_v) &¬ detail:: - is_use_rmem_A()>> -{ - static_assert(is_static::value); - static_assert(is_static::value); -#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); -#endif - static_assert(detail::is_aligned(), - "Should meet TMA alignment requirement\n"); - - static constexpr bool IsArrayOfPointersGemm - = (cute::is_same_v); - static constexpr bool IsFP8Input = detail::is_input_fp8(); - static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm), - "Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n"); - - // For fp32 types, map to tf32 MMA value type - using MmaElementA = cute::conditional_t, tfloat32_t, ElementA>; - using MmaElementB = cute::conditional_t, tfloat32_t, ElementB>; - - static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); - static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - - using AtomLayoutMNK = cute::conditional_t - || IsArrayOfPointersGemm, - Layout>, Layout>>; - - using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector(), - AtomLayoutMNK{})); - - using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); - using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); - - using SmemLayoutAtomA = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutAtomB = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - - static constexpr int PipelineStages - = detail::compute_stage_count_or_override_gated(StageCountType{}); - using DispatchPolicy = cute::conditional_t, - /* For FP8 use a separate mainloop compared to other datatypes */ - cute::conditional_t, - MainloopSm90TmaGmmaWarpSpecialized>>; - - using SmemCopyAtomA = void; - using SmemCopyAtomB = void; - - using CollectiveOp = CollectiveMmaGated, - ElementB, TagToStrideB_t, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// GMMA_TMA_WS_FP8_FAST_ACCUM_SS -template class Activation, bool SwapAB> -struct CollectiveBuilderGated - || cute::is_same_v - || cute::is_same_v - || cute::is_same_v>> -{ - static_assert(is_static::value); - 
static_assert(is_static::value); - static_assert(detail::is_aligned(), - "Not meet TMA alignment requirement yet\n"); - static_assert( - detail::is_input_fp8(), "Only FP8 datatypes are compatible with these kernel schedules\n"); - // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder - static_assert(!detail::is_use_rmem_A(), - "Not supported for fp8 non-TN warp specialized kernels yet\n"); -#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED - static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); -#endif - - static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); - static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); - - static constexpr bool IsArrayOfPointersGemm - = (cute::is_same_v); - using AtomLayoutMNK - = cute::conditional_t - || IsArrayOfPointersGemm, - Layout>, Layout>>; - - using TiledMma = decltype(cute::make_tiled_mma( - cute::GMMA::ss_op_selector(), - AtomLayoutMNK{})); - - using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); - using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); - - using SmemLayoutAtomA = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - using SmemLayoutAtomB = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); - - static constexpr int PipelineStages - = detail::compute_stage_count_or_override_gated(StageCountType{}); - using DispatchPolicy = cute::conditional_t, - MainloopSm90TmaGmmaWarpSpecialized>; - - using SmemCopyAtomA = void; - using SmemCopyAtomB = void; - - using CollectiveOp = CollectiveMmaGated, - ElementB, TagToStrideB_t, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, - GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp deleted file mode 100644 index 2f2422c9914..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp +++ /dev/null @@ -1,58 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. 
Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -///////////////////////////////////////////////////////////////////////////////////////////////// -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp" - -namespace cutlass::gemm::collective -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template class Activation, - bool SwapAB = false, class Enable = void> -struct CollectiveBuilderGated -{ - static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl" -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp deleted file mode 100644 index d850f36df5f..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp +++ /dev/null @@ -1,59 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/detail/dependent_false.hpp" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template class Activation, bool SwapAB = false> -struct CollectiveMmaGated -{ - static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp" -#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp" -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp deleted file mode 100644 index dcba6ee6377..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp +++ /dev/null @@ -1,642 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cute/arch/cluster_sm90.hpp" -#include "cute/arch/copy_sm90.hpp" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/dispatch_policy.hpp" - -#include "cute/algorithm/functional.hpp" -#include "cute/algorithm/gemm.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cute/numeric/arithmetic_tuple.hpp" -#include "cute/tensor_predicate.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/trace.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective -{ -using namespace cute; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// WarpSpecialized Mainloop -template class Activation_, bool SwapAB_> -struct CollectiveMmaGated, TileShape_, - ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, - GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_> -{ - static constexpr bool isGated = true; - static constexpr bool SwapAB = SwapAB_; - - // - // Type Aliases - // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized; - using TileShape = TileShape_; - using ElementA = ElementA_; - using StrideA = StrideA_; - using ElementB = ElementB_; - using StrideB = StrideB_; - using TiledMma = TiledMma_; - using ElementAccumulator = typename TiledMma::ValTypeC; - using GmemTiledCopyA = GmemTiledCopyA_; - using GmemTiledCopyB = GmemTiledCopyB_; - using SmemLayoutAtomA = SmemLayoutAtomA_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using SmemCopyAtomA = SmemCopyAtomA_; - using SmemCopyAtomB = SmemCopyAtomB_; - using TransformA = TransformA_; - using TransformB = TransformB_; - using ArchTag = typename DispatchPolicy::ArchTag; - using Activation = Activation_; - - using ElementAux = cute::conditional_t; - using ValTypeAux = cute::conditional_t; - - using MainloopPipeline = cutlass::PipelineTmaAsync; - using PipelineState = cutlass::PipelineState; - - using PipelineParams = typename MainloopPipeline::Params; - - static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert( - (size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert( - (size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert( - (size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly 
divide tile shape."); - static_assert( - (size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - // Tile along modes in a way that maximizes the TMA box size. - using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), - conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); - using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), - conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); - using SmemLayoutAux = cute::conditional_t; - - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); - static_assert(cute::is_base_of::value - && cute::is_base_of::value, - "MMA atom must source both A and B operand from smem_desc for this mainloop."); - static_assert( - cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert( - cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - - // TMA converts f32 input to tf32 when copying from GMEM to SMEM - // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. - static constexpr bool ConvertF32toTF32A = cute::is_same_v; - static constexpr bool ConvertF32toTF32B = cute::is_same_v; - using InternalElementA = cute::conditional_t>>; - using InternalElementB = cute::conditional_t>>; - using InternalElementAux = cute::conditional_t; - - struct SharedStorage - { - struct TensorStorage : cute::aligned_struct<128> - { - cute::array_aligned> smem_A; - cute::array_aligned> smem_B; - cute::array_aligned> smem_Aux; - } tensors; - - using PipelineStorage = typename MainloopPipeline::SharedStorage; - PipelineStorage pipeline; - }; - - using TensorStorage = typename SharedStorage::TensorStorage; - using PipelineStorage = typename SharedStorage::PipelineStorage; - - // Host side kernel arguments - struct Arguments - { - ElementA const* ptr_A; - StrideA dA; - ElementB const* ptr_B; - StrideB dB; - float scale_d0 = 1.0f; - float scale_d1 = 1.0f; - uint32_t mma_promotion_interval = 4; - }; - - // Device side kernel params - struct Params - { - // Assumption: StrideA is congruent with Problem_MK - using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, - make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), - SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any - // Assumption: StrideB is congruent with Problem_NK - using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, - make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), - SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any - using TMA_Aux = cute::conditional_t; - TMA_A tma_load_a; - TMA_B tma_load_b; - TMA_Aux tma_load_aux; - float scale_d0 = 1.0f; - float scale_d1 = 1.0f; - }; - - // - // Methods - // - - template - static constexpr Params to_underlying_arguments( - ProblemShape const& problem_shape, Arguments const& args, void* workspace) - { - (void) workspace; - - // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) - auto 
problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - auto ptr_A = reinterpret_cast(args.ptr_A); - auto ptr_B = reinterpret_cast(args.ptr_B); - - Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); - Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); - typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, - SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any - typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b, - SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - - if constexpr (SwapAB) - { - auto ptr_Aux = reinterpret_cast(args.ptr_A + size(make_shape(M, K, L))); - Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); - typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux, - SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any - return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; - } - else - { - auto ptr_Aux = reinterpret_cast(args.ptr_B + size(make_shape(N, K, L))); - Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); - typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux, - SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; - } - } - - template - static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) - { - constexpr int tma_alignment_bits = 128; - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - bool implementable = true; - constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable - && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); - constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable - && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); - - if (!implementable) - { - CUTLASS_TRACE_HOST( - " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - } - return implementable; - } - - static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; - static constexpr int K_PIPE_MMAS = 1; - static constexpr uint32_t TmaTransactionBytes - = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8 - + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8 - + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast(sizeof_bits::value)) - / 8; - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) - { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - 
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor()); - } - - /// Set up the data needed by this collective for load and mma. - /// Returns a tuple of tensors. The collective and the kernel layer have the contract - /// Returned tuple must contain at least two elements, with the first two elements being: - /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) - /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) - /// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l) - /// The rest of the tensors can be specified as needed by this collective. - template - CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const - { - using X = Underscore; - // Separate out problem shape for convenience - auto [M, N, K, L] = problem_shape_MNKL; - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) - Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) - - // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) - - if constexpr (SwapAB) - { - Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) - Tensor gAux_xkl - = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) - return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); - } - else - { - Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) - Tensor gAux_xkl - = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) - return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); - } - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Producer Perspective - template - CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, - cute::tuple const& load_inputs, BlockCoord const& blk_coord, - KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) - { - int lane_predicate = cute::elect_one_sync(); - - if (lane_predicate) - { - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); - - // - // Prepare the TMA loads for A and B - // - - constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); - uint2 cluster_local_block_id - = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - Tensor gAux_xkl = get<2>(load_inputs); - - auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); - auto block_tma_aux = SwapAB ? 
mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y) - : mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x); - // Partition the inputs based on the current block coordinates. - auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) - Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); - - // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) - - Tensor tAuxgAux = block_tma_aux.partition_S(gAux); - Tensor tAuxsAux = block_tma_aux.partition_D(sAux); - - uint16_t mcast_mask_a = 0; - uint16_t mcast_mask_b = 0; - uint16_t mcast_mask_aux = 0; - - // Issue TmaLoads - // Maps the tile -> block, value - if constexpr (cute::is_same_v) - { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int n = 0; n < size<1>(block_layout); ++n) - { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); - } - } - - if constexpr (cute::is_same_v) - { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) - { - mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); - } - } - - if constexpr (SwapAB) - { - mcast_mask_aux = mcast_mask_a; - } - else - { - mcast_mask_aux = mcast_mask_b; - } - - // Mainloop - CUTLASS_PRAGMA_NO_UNROLL - for (; k_tile_count > 0; --k_tile_count) - { - // LOCK smem_pipe_write for _writing_ - pipeline.producer_acquire(smem_pipe_write); - - // - // Copy gmem to smem for *k_tile_iter - // - - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); - - int write_stage = smem_pipe_write.index(); - copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), - tAsA(_, _, _, write_stage)); - copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), - tBsB(_, _, _, write_stage)); - copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter), - tAuxsAux(_, _, _, write_stage)); - ++k_tile_iter; - - // Advance smem_pipe_write - ++smem_pipe_write; - } - } - } - - /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster - CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) - { - int lane_predicate = cute::elect_one_sync(); - - // Issue the epilogue waits - if (lane_predicate) - { - /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all - * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was - * still inverted from make_producer_start_state - */ - pipeline.producer_tail(smem_pipe_write); - } - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Consumer Perspective - template - CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0, - FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors, - Params const& mainloop_params) - { - static_assert(is_rmem::value, "C tensor must be rmem 
resident."); - static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::rank(SmemLayoutAux{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); - - // - // Define C accumulators and A/B partitioning - // - - TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - - Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) - - // Allocate "fragments/descriptors" - Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) - - auto tCsAux = [&]() -> auto - { - if constexpr (SwapAB) - { - return thread_mma.partition_A(sAux); - } - else - { - return thread_mma.partition_B(sAux); - } - }(); - auto tCrAux = [&]() -> auto - { - if constexpr (SwapAB) - { - return thread_mma.make_fragment_A(tCsAux); - } - else - { - return thread_mma.make_fragment_B(tCsAux); - } - }(); - - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - if constexpr (SwapAB) - { - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE - } - else - { - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE - } - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sAux)); // PIPE - - // - // PIPELINED MAIN LOOP - // - static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight"); - - // We release buffers to producer warps(dma load) with some mmas in flight - PipelineState smem_pipe_release = smem_pipe_read; - - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - - warpgroup_fence_operand(accum0); - warpgroup_fence_operand(accum1); - CUTLASS_PRAGMA_UNROLL - for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - int read_stage = smem_pipe_read.index(); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 
1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) - { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0); - if constexpr (SwapAB) - { - cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1); - } - else - { - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1); - } - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - - warpgroup_commit_batch(); - - ++smem_pipe_read; - } - - warpgroup_fence_operand(accum0); - warpgroup_fence_operand(accum1); - // Mainloop GMMAs - k_tile_count -= prologue_mma_count; - - CUTLASS_PRAGMA_NO_UNROLL - for (; k_tile_count > 0; --k_tile_count) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - // - // Compute on k_tile - // - - int read_stage = smem_pipe_read.index(); - warpgroup_fence_operand(accum0); - warpgroup_fence_operand(accum1); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) - { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0); - if constexpr (SwapAB) - { - cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1); - } - else - { - cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1); - } - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed - warpgroup_wait(); - warpgroup_fence_operand(accum0); - warpgroup_fence_operand(accum1); - - // UNLOCK smem_pipe_release, done _computing_ on it - pipeline.consumer_release(smem_pipe_release); - - // Advance smem_pipe_read and smem_pipe_release - ++smem_pipe_read; - ++smem_pipe_release; - } - - warpgroup_fence_operand(accum0); - warpgroup_fence_operand(accum1); - } - - /// Perform a Consumer Epilogue to release all buffers - CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) - { - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - k_tile_count -= prologue_mma_count; - - smem_pipe_release.advance(k_tile_count); - - // Wait on all GMMAs to complete - warpgroup_wait<0>(); - - for (int count = 0; count < prologue_mma_count; ++count) - { - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it - ++smem_pipe_release; - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp deleted file mode 100644 index 72c1adf293f..00000000000 --- 
a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp +++ /dev/null @@ -1,665 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ - -#pragma once - -#include "cute/arch/cluster_sm90.hpp" -#include "cute/arch/copy_sm90.hpp" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/dispatch_policy.hpp" - -#include "cute/algorithm/functional.hpp" -#include "cute/algorithm/gemm.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cute/numeric/arithmetic_tuple.hpp" -#include "cute/tensor_predicate.hpp" -#include "cutlass/epilogue/thread/activation.h" -#include "cutlass/gemm/collective/fp8_accumulation.hpp" -#include "cutlass/trace.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective -{ -using namespace cute; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -// WarpSpecialized Mainloop -template class Activation_, bool SwapAB_> -struct CollectiveMmaGated, TileShape_, - ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, - GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_> -{ - static constexpr bool isGated = true; - static constexpr bool SwapAB = SwapAB_; - - // - // Type Aliases - // - using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8; - using TileShape = TileShape_; - using ElementA = ElementA_; - using StrideA = StrideA_; - using ElementB = ElementB_; - using StrideB = StrideB_; - using TiledMma = TiledMma_; - using ElementAccumulator = typename TiledMma::ValTypeC; - using GmemTiledCopyA = GmemTiledCopyA_; - using GmemTiledCopyB = GmemTiledCopyB_; - using SmemLayoutAtomA = SmemLayoutAtomA_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using SmemCopyAtomA = SmemCopyAtomA_; - using SmemCopyAtomB = SmemCopyAtomB_; - using TransformA = TransformA_; - using TransformB = TransformB_; - using ArchTag = typename DispatchPolicy::ArchTag; - using Activation = Activation_; - - using ElementAux = cute::conditional_t; - using ValTypeAux = cute::conditional_t; - - using MainloopPipeline = cutlass::PipelineTmaAsync; - using PipelineState = cutlass::PipelineState; - - using PipelineParams = typename MainloopPipeline::Params; - - static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert( - (size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert( - (size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert( - (size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert( - (size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - // Tile along modes in a way that maximizes the TMA box size. 
- using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), - conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); - using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), - conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); - using SmemLayoutAux = cute::conditional_t; - - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(cute::is_base_of::value - && cute::is_base_of::value, - "MMA atom must source both A and B operand from smem_desc for this mainloop."); - static_assert( - cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - static_assert( - cute::is_same_v || cute::is_same_v, - "GmemTiledCopy - invalid SM90 TMA copy atom specified."); - - struct SharedStorage - { - struct TensorStorage : cute::aligned_struct<128> - { - cute::array_aligned> smem_A; - cute::array_aligned> smem_B; - cute::array_aligned> smem_Aux; - } tensors; - - using PipelineStorage = typename MainloopPipeline::SharedStorage; - PipelineStorage pipeline; - }; - - using TensorStorage = typename SharedStorage::TensorStorage; - using PipelineStorage = typename SharedStorage::PipelineStorage; - - // Host side kernel arguments - struct Arguments - { - ElementA const* ptr_A; - StrideA dA; - ElementB const* ptr_B; - StrideB dB; - float scale_d0 = 1.0f; - float scale_d1 = 1.0f; - uint32_t mma_promotion_interval = 4; - }; - - // Device side kernel params - struct Params - { - // Assumption: StrideA is congruent with Problem_MK - using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, - make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), - SmemLayoutA{}(_, _, 0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any - // Assumption: StrideB is congruent with Problem_NK - using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, - make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), - SmemLayoutB{}(_, _, 0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any - using TMA_Aux = cute::conditional_t; - TMA_A tma_load_a; - TMA_B tma_load_b; - TMA_Aux tma_load_aux; - float scale_d0 = 1.0f; - float scale_d1 = 1.0f; - uint32_t mma_promotion_interval = 4; - }; - - // - // Methods - // - - template - static constexpr Params to_underlying_arguments( - ProblemShape const& problem_shape, Arguments const& args, void* workspace) - { - (void) workspace; - - // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - auto ptr_A = reinterpret_cast(args.ptr_A); - auto ptr_B = reinterpret_cast(args.ptr_B); - - Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); - Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); - typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, - SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any - typename Params::TMA_B tma_load_b 
= make_tma_copy(GmemTiledCopyB{}, tensor_b, - SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - if constexpr (SwapAB) - { - auto ptr_Aux = reinterpret_cast(args.ptr_A + size(make_shape(M, K, L))); - Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); - typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux, - SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), - size<1>(ClusterShape{})); // mcast along N mode for this M load, if any - return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval}; - } - else - { - auto ptr_Aux = reinterpret_cast(args.ptr_B + size(make_shape(N, K, L))); - Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); - typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux, - SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), - size<0>(ClusterShape{})); // mcast along M mode for this N load, if any - return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval}; - } - } - - template - static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) - { - constexpr int tma_alignment_bits = 128; - auto problem_shape_MNKL = append<4>(problem_shape, 1); - auto [M, N, K, L] = problem_shape_MNKL; - - bool implementable = true; - constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable - && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); - constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; - implementable = implementable - && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); - /* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA - * instructions. */ - implementable = implementable && (args.mma_promotion_interval % 4 == 0); - - if (!implementable) - { - CUTLASS_TRACE_HOST( - " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); - } - return implementable; - } - - static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; - static constexpr int K_PIPE_MMAS = 1; - static constexpr uint32_t TmaTransactionBytes - = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8 - + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8 - + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast(sizeof_bits::value)) - / 8; - - /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance - CUTLASS_DEVICE - static void prefetch_tma_descriptors(Params const& mainloop_params) - { - cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); - cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor()); - } - - /// Set up the data needed by this collective for load and mma. - /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract - /// Returned tuple must contain at least two elements, with the first two elements being: - /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) - /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) - /// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l) - template - CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const - { - using X = Underscore; - // Separate out problem shape for convenience - auto [M, N, K, L] = problem_shape_MNKL; - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) - Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) - - // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) - - if constexpr (SwapAB) - { - Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) - Tensor gAux_xkl - = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) - return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); - } - else - { - Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) - Tensor gAux_xkl - = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) - return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); - } - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Producer Perspective - template - CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, - cute::tuple const& load_inputs, BlockCoord const& blk_coord, - KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, - TensorStorage& shared_tensors) - { - int lane_predicate = cute::elect_one_sync(); - - if (lane_predicate) - { - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); - - // - // Prepare the TMA loads for A and B - // - - constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); - uint2 cluster_local_block_id - = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - Tensor gAux_xkl = get<2>(load_inputs); - - auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); - auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y) - : mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x); - - // Partition the inputs based on the current block coordinates. 
- auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; - Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) - Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); - - // Applies the mapping from block_tma_a - Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) - Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) - - Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) - Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) - - Tensor tAuxgAux = block_tma_aux.partition_S(gAux); - Tensor tAuxsAux = block_tma_aux.partition_D(sAux); - - uint16_t mcast_mask_a = 0; - uint16_t mcast_mask_b = 0; - uint16_t mcast_mask_aux = 0; - - // Issue TmaLoads - // Maps the tile -> block, value - if constexpr (cute::is_same_v) - { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int n = 0; n < size<1>(block_layout); ++n) - { - mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); - } - } - - if constexpr (cute::is_same_v) - { - auto block_layout = Layout{}; // (m,n) -> block_id - for (int m = 0; m < size<0>(block_layout); ++m) - { - mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); - } - } - - if constexpr (SwapAB) - { - mcast_mask_aux = mcast_mask_a; - } - else - { - mcast_mask_aux = mcast_mask_b; - } - - // Mainloop - CUTLASS_PRAGMA_NO_UNROLL - for (; k_tile_count > 0; --k_tile_count) - { - // LOCK smem_pipe_write for _writing_ - pipeline.producer_acquire(smem_pipe_write); - - // - // Copy gmem to smem for *k_tile_iter - // - - using BarrierType = typename MainloopPipeline::ProducerBarrierType; - BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); - - int write_stage = smem_pipe_write.index(); - copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), - tAsA(_, _, _, write_stage)); - copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), - tBsB(_, _, _, write_stage)); - copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter), - tAuxsAux(_, _, _, write_stage)); - ++k_tile_iter; - - // Advance smem_pipe_write - ++smem_pipe_write; - } - } - } - - /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster - CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) - { - int lane_predicate = cute::elect_one_sync(); - - // Issue the epilogue waits - if (lane_predicate) - { - /* This helps avoid early exit of blocks in Cluster - * Waits for all stages to either be released (all - * Consumer UNLOCKs), or if the stage was never used - * then would just be acquired since the phase was - * still inverted from make_producer_start_state - */ - pipeline.producer_tail(smem_pipe_write); - } - } - - /// Perform a collective-scoped matrix multiply-accumulate - /// Consumer Perspective - template - CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0, - FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors, - Params const& mainloop_params) - { - - static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); - 
static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - static_assert(cute::is_void_v, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - - Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); - - // - // Define C accumulators and A/B partitioning - // - - TiledMma tiled_mma; - auto thread_mma = tiled_mma.get_thread_slice(thread_idx); - - Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) - - // Allocate "fragments/descriptors" - Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) - - auto tCsAux = [&]() -> auto - { - if constexpr (SwapAB) - { - return thread_mma.partition_A(sAux); - } - else - { - return thread_mma.partition_B(sAux); - } - }(); - auto tCrAux = [&]() -> auto - { - if constexpr (SwapAB) - { - return thread_mma.make_fragment_A(tCsAux); - } - else - { - return thread_mma.make_fragment_B(tCsAux); - } - }(); - - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - if constexpr (SwapAB) - { - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE - } - else - { - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE - } - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sAux)); // PIPE - - // - // PIPELINED MAIN LOOP - // - static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight"); - - // We release buffers to producer warps(dma load) with some mmas in flight - PipelineState smem_pipe_release = smem_pipe_read; - - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - - GmmaFP8Accumulation accumulation0(accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA)); - GmmaFP8Accumulation accumulation1(accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA)); - warpgroup_fence_operand(accumulation0()); - warpgroup_fence_operand(accumulation1()); - CUTLASS_PRAGMA_UNROLL - for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - if (accumulation0.prepare_if_needed()) - { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - - int read_stage = smem_pipe_read.index(); - 
warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) - { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm( - tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0()); - if constexpr (SwapAB) - { - cute::gemm( - tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1()); - } - else - { - cute::gemm( - tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1()); - } - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - - accumulation0.promote_if_needed(); - accumulation1.promote_if_needed(); - - ++smem_pipe_read; - } - - warpgroup_fence_operand(accumulation0()); - warpgroup_fence_operand(accumulation1()); - // Mainloop GMMAs - k_tile_count -= prologue_mma_count; - - CUTLASS_PRAGMA_NO_UNROLL - for (; k_tile_count > 0; --k_tile_count) - { - // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) - auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); - pipeline.consumer_wait(smem_pipe_read, barrier_token); - - // - // Compute on k_tile - // - - int read_stage = smem_pipe_read.index(); - - if (accumulation0.prepare_if_needed()) - { - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - } - - warpgroup_fence_operand(accumulation0()); - warpgroup_fence_operand(accumulation1()); - warpgroup_arrive(); - // Unroll the K mode manually to set scale D to 1 - CUTLASS_PRAGMA_UNROLL - for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) - { - // (V,M,K) x (V,N,K) => (V,M,N) - cute::gemm( - tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0()); - if constexpr (SwapAB) - { - cute::gemm( - tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1()); - } - else - { - cute::gemm( - tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1()); - } - tiled_mma.accumulate_ = GMMA::ScaleOut::One; - } - warpgroup_commit_batch(); - - /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed - warpgroup_wait(); - warpgroup_fence_operand(accumulation0()); - warpgroup_fence_operand(accumulation1()); - - accumulation0.promote_if_needed(); - accumulation1.promote_if_needed(); - - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it - - // Advance smem_pipe_read and smem_pipe_release - ++smem_pipe_read; - ++smem_pipe_release; - } - - accumulation0.promote_residue_if_needed(); - accumulation1.promote_residue_if_needed(); - - warpgroup_fence_operand(accumulation0()); - warpgroup_fence_operand(accumulation1()); - } - - /// Perform a Consumer Epilogue to release all buffers - CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) - { - // Prologue GMMAs - int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); - k_tile_count -= prologue_mma_count; - - smem_pipe_release.advance(k_tile_count); - - // Wait on all GMMAs to complete - warpgroup_wait<0>(); - - for (int count = 0; count < prologue_mma_count; ++count) - { - pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it - ++smem_pipe_release; - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - 
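Editorial note on the two deleted gated mainloops above: each keeps two accumulators per tile, accum0 from A x B and accum1 from A x Aux (or Aux x B when SwapAB is set), with the Aux operand loaded from the second half of the A or B allocation. The scale_d0/scale_d1 arguments and the Activation template parameter suggest the usual gated-GEMM combine applied after the K loop; the sketch below is only an illustrative host-side reference of that assumed pattern, with a SiLU stand-in for the generic Activation — it is not the removed kernels' actual epilogue code.

// Assumed gating pattern: d[i] = act(scale_d0 * accum0[i]) * (scale_d1 * accum1[i]).
#include <cmath>
#include <cstddef>
#include <vector>

static float silu(float x) { return x / (1.0f + std::exp(-x)); }  // stand-in for Activation

std::vector<float> gated_combine(std::vector<float> const& accum0,  // from A x B
                                 std::vector<float> const& accum1,  // from A x Aux (or Aux x B)
                                 float scale_d0, float scale_d1) {
    std::vector<float> d(accum0.size());
    for (std::size_t i = 0; i < accum0.size(); ++i) {
        d[i] = silu(scale_d0 * accum0[i]) * (scale_d1 * accum1[i]);
    }
    return d;
}

The FP8 variant additionally wraps both accumulators in GmmaFP8Accumulation and promotes partial sums every mma_promotion_interval K-blocks, which is why accumulation0 and accumulation1 call promote_if_needed() in lockstep in the loop above.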
-///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h deleted file mode 100644 index 2edd5a228b4..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h +++ /dev/null @@ -1,438 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and - batched array variants. -*/ - -#pragma once - -// #include - -#include "cutlass/arch/arch.h" -#include "cutlass/cutlass.h" -#include "cutlass/device_kernel.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/gemm_universal.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle.h" - -#include "cutlass/gemm/device/default_gemm_configuration.h" -#include "cutlass/gemm/kernel/default_gemm_universal.h" - -#include "cutlass/trace.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace device -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/* - This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) - It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs - and SmoothQuant. 
The newer device layer is not compatible with these older kernel level APIs. - - Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support - that feature at the moment. - */ - -template -class GemmUniversalBaseCompat -{ -public: - using GemmKernel = GemmKernel_; - using ThreadblockShape = typename GemmKernel::Mma::Shape; - - using ElementA = typename GemmKernel::ElementA; - using LayoutA = typename GemmKernel::LayoutA; - using TensorRefA = TensorRef; - static ComplexTransform const kTransformA = GemmKernel::kTransformA; - - using ElementB = typename GemmKernel::ElementB; - using LayoutB = typename GemmKernel::LayoutB; - using TensorRefB = TensorRef; - static ComplexTransform const kTransformB = GemmKernel::kTransformB; - - using ElementC = typename GemmKernel::ElementC; - using LayoutC = typename GemmKernel::LayoutC; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; - - using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; - - using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; - using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; - using Operator = typename GemmKernel::Operator; - - /// Argument structure - using Arguments = typename GemmKernel::Arguments; - -protected: - /// Kernel parameters object - typename GemmKernel::Params params_; - -protected: - /// Private helper to obtain the grid dimensions with fix-up for split-K - static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) - { - - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; - - grid_tiled_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); - - gemm_k_size = args.problem_size.k(); - - if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) - { - - int const kAlignK - = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); - - gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); - - if (gemm_k_size) - { - grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); - } - } - } - -public: - /// Constructs the GEMM. - GemmUniversalBaseCompat() {} - - /// Determines whether the GEMM can execute the given problem. 
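
GemmUniversalBaseCompat::get_grid_shape_() above sizes the serial/parallel split-K decomposition: kAlignK is the larger of the two operands' 128-bit element counts, the per-split K extent is ceil_div(K, batch_count) rounded up to that alignment, and the K dimension of the tiled grid is then recomputed from it. The standalone sketch below reproduces that arithmetic on plain integers; the fp16 operand widths and the K = 4096, split-3 problem are assumptions made up for the example.

#include <algorithm>
#include <cstdio>

// Reproduces the split-K sizing arithmetic of the deleted get_grid_shape_()
// helper on plain integers (std::max stands in for const_max).
int ceil_div(int a, int b) { return (a + b - 1) / b; }
int round_up(int a, int b) { return ceil_div(a, b) * b; }

int main() {
    int k = 4096;                   // problem K (example value)
    int batch_count = 3;            // split-K factor in kGemm / kGemmSplitKParallel mode
    int bits_a = 16, bits_b = 16;   // assumed fp16 A and fp16 B operands

    int align_k = std::max({128 / bits_a, 128 / bits_b, 1});
    int gemm_k_size = round_up(ceil_div(k, batch_count), align_k);
    int grid_k = ceil_div(k, gemm_k_size);

    std::printf("align_k=%d gemm_k_size=%d grid_k=%d\n", align_k, gemm_k_size, grid_k);
    // -> align_k=8 gemm_k_size=1368 grid_k=3
    return 0;
}
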
- static Status can_implement(Arguments const& args) - { - - // Determine grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - - ThreadblockSwizzle threadblock_swizzle; - dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); - - uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); - - if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) - { - - return Status::kErrorInvalidProblem; - } - - return GemmKernel::can_implement(args); - } - - /// Gets the workspace size - static size_t get_workspace_size(Arguments const& args) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); - - size_t workspace_bytes = 0; - - // Determine grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - - if (args.mode == GemmUniversalMode::kGemmSplitKParallel) - { - - // Split-K parallel always requires a temporary workspace - workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); - } - else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) - { - - // Serial split-K only requires a temporary workspace if the number of partitions along the - // GEMM K dimension is greater than one. - workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); - } - - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - - workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); - - return workspace_bytes; - } - - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const& args) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); - - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); - - CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" - << " result = {" << result << "}"); - - return result; - } - - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); - - int max_active_blocks = -1; - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - - if (smem_size <= (48 << 10)) - { - - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); - - if (result == cudaSuccess) - { - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; - } - } - else - { - - // Query assuming zero shared memory then compute occupancy limit based on SMEM - cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); - - if (result != cudaSuccess) - { - - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); - - return -1; - } - - if (smem_capacity < 0) - { - int device_idx = 0; - result = cudaGetDevice(&device_idx); - - if (result != cudaSuccess) - { - return -1; - } - - cudaDeviceProp properties; - result = cudaGetDeviceProperties(&properties, device_idx); - - if (result != cudaSuccess) - { - return -1; - } - - 
smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); - } - - int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); - - CUTLASS_TRACE_HOST(" occupancy: " << occupancy); - - return occupancy; - } - - CUTLASS_TRACE_HOST(" returning internal error"); - - return -1; - } - - /// Initializes GEMM state from arguments. - Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " - << workspace << ", stream: " << (stream ? "non-null" : "null")); - - size_t workspace_bytes = get_workspace_size(args); - - CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); - - if (workspace_bytes) - { - - if (!workspace) - { - CUTLASS_TRACE_HOST(" error: device workspace must not be null"); - - return Status::kErrorWorkspaceNull; - } - - if (args.mode == GemmUniversalMode::kGemm) - { - CUTLASS_TRACE_HOST(" clearing device workspace"); - cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); - - return Status::kErrorInternal; - } - } - } - - // Get CUDA grid shape - cutlass::gemm::GemmCoord grid_tiled_shape; - int gemm_k_size = 0; - - get_grid_shape_(grid_tiled_shape, gemm_k_size, args); - - // Initialize the Params structure - params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); - - // Specify shared memory capacity for kernel. - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - if (smem_size >= (48 << 10)) - { - cudaError_t result - = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - return Status::kErrorInternal; - } - } - - return Status::kSuccess; - } - - /// Lightweight update given a subset of arguments - Status update(Arguments const& args, void* workspace = nullptr) - { - - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); - - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; - } - - params_.update(args, workspace); - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) - { - CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); - - // - // Configure grid and block dimensions - // - - ThreadblockSwizzle threadblock_swizzle; - - dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); - dim3 block(GemmKernel::kThreadCount, 1, 1); - - int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); - - // - // Launch kernel - // - - CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); - - // Launch - cutlass::Kernel<<>>(params_); - - // - // Query for errors - // - cudaError_t result = cudaGetLastError(); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) - { - return run(stream); - } - - /// Runs the kernel using initialized state. 
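
Both maximum_active_blocks() and initialize() above rely on the same CUDA runtime pattern: a kernel that needs more than the 48 KiB default of dynamic shared memory must opt in via cudaFuncSetAttribute, and occupancy is the occupancy-API result clamped by how many shared-memory allocations fit on one SM. A minimal sketch of that pattern follows, using a stand-in toy_kernel instead of the cutlass::Kernel entry point and with error handling omitted; the 64 KiB requirement and 256-thread block size are example values.

#include <algorithm>
#include <cstdio>
#include <cuda_runtime.h>

// Toy kernel standing in for the CUTLASS device entry point; it only touches
// dynamic shared memory so the attribute/occupancy queries have a target.
__global__ void toy_kernel(float* out) {
    extern __shared__ float smem[];
    smem[threadIdx.x] = static_cast<float>(threadIdx.x);
    __syncthreads();
    if (threadIdx.x == 0) out[blockIdx.x] = smem[0];
}

int main() {
    int smem_size = 64 * 1024;  // pretend the kernel needs 64 KiB, above the 48 KiB default
    int threads = 256;

    // Kernels wanting >48 KiB of dynamic shared memory must opt in first,
    // exactly as the deleted initialize()/maximum_active_blocks() did.
    cudaFuncSetAttribute(toy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

    int max_active_blocks = 0;
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, toy_kernel, threads, smem_size);

    cudaDeviceProp props{};
    cudaGetDeviceProperties(&props, 0);
    int smem_limited = static_cast<int>(props.sharedMemPerMultiprocessor) / smem_size;

    std::printf("occupancy = %d blocks/SM\n", std::min(max_active_blocks, smem_limited));
    return 0;
}
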
- Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { - - Status status = initialize(args, workspace, stream); - - if (status == Status::kSuccess) - { - status = run(stream); - } - - return status; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h deleted file mode 100644 index bfd3666b9c1..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h +++ /dev/null @@ -1,542 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! 
- \file - \brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h -*/ - -#pragma once - -#include -#include -#include - -#include "cutlass/arch/arch.h" -#include "cutlass/cutlass.h" -#include "cutlass/device_kernel.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/gemm_universal.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle.h" - -#include "cutlass/gemm/device/default_gemm_configuration.h" -#include "cutlass/gemm/kernel/default_gemm_universal.h" - -#include "cutlass/trace.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace device -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk, - int64_t* splitk_buffer_offsets) -{ - // in_tensor: [problem_idx, k_partition, hidden_size] - // Note that different requests of in_tensor might have different hidden_size (=m*n) - // so, we need to use splitk_buffer_offsets. - // out_tensor: problem_idx * [hidden_size] - - int const problem_idx = blockIdx.y; - GemmCoord problem = problem_sizes[problem_idx]; - int const hidden_size = problem.m() * problem.n(); - const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk; - T_OUT* out_tensor_ = out_tensor[problem_idx]; - - for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x) - { - float sum = 0.0f; - for (int k_idx = 0; k_idx < splitk; k_idx++) - { - sum += (float) in_tensor_[k_idx * hidden_size + i]; - } - out_tensor_[i] = (T_OUT) (sum); - } -} - -/// GEMM Grouped -template -class BaseSplitkGrouped -{ -public: - using BaseKernel = BaseKernel_; - - using ElementA = typename BaseKernel::ElementA; - using LayoutA = typename BaseKernel::LayoutA; - using TensorRefA = TensorRef; - static ComplexTransform const kTransformA = BaseKernel::kTransformA; - static int const kAlignmentA = BaseKernel::kAlignmentA; - - using ElementB = typename BaseKernel::ElementB; - using LayoutB = typename BaseKernel::LayoutB; - using TensorRefB = TensorRef; - static ComplexTransform const kTransformB = BaseKernel::kTransformB; - static int const kAlignmentB = BaseKernel::kAlignmentB; - - using ElementC = typename BaseKernel::ElementC; - using LayoutC = typename BaseKernel::LayoutC; - using TensorRefC = TensorRef; - using TensorRefD = TensorRef; - static int const kAlignmentC = BaseKernel::kAlignmentC; - - using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; - - using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; - using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle; - - using Operator = typename BaseKernel::Operator; - using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; - - using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; - using MathOperator = typename WarpMmaOperator::MathOperator; - using OperatorClass = typename WarpMmaOperator::OperatorClass; - using ArchTag = typename WarpMmaOperator::ArchTag; - using ThreadblockShape = typename BaseKernel::Mma::Shape; - using WarpShape = typename BaseKernel::WarpShape; - using InstructionShape = typename BaseKernel::InstructionShape; - static int const kStages = BaseKernel::Mma::kStages; - - /// Argument structure - using Arguments = typename BaseKernel::Arguments; - - using 
ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; - -protected: - /// Kernel parameters object - typename BaseKernel::Params gemm_params_; - -private: - /// Get the number of tiles across all problems in a group - static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) - { - int32_t tiles = 0; - for (int32_t i = 0; i < problem_count; ++i) - { - cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; - BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); - tiles += problem_tile_count(problem); - } - return tiles; - } - - /// Copy from `data` to `workspace` - Status copy_to_workspace(void* workspace, void* data, size_t bytes) - { - cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); - if (cuda_error != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - cuda_error = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error)); - return Status::kErrorInternal; - } - - return Status::kSuccess; - } - - /// Precomputes scheduling information for the grouped GEMM - Status precompute(Arguments const& args, int32_t tile_count, void* workspace) - { - size_t workspace_bytes = get_workspace_size(args); - std::vector host_workspace(workspace_bytes); - BaseKernel::ProblemVisitor::host_precompute( - args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data()); - return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); - } - - /// Reorder `data` according to `indices` - template - static void reorder_array(T* data, std::vector const& indices) - { - // For now, simply create a copy of the data and then copy over to the original. - std::vector copy(indices.size()); - for (size_t i = 0; i < indices.size(); ++i) - { - copy.at(i) = data[indices[i]]; - } - - memcpy(data, copy.data(), indices.size() * sizeof(T)); - } - -public: - /// Constructs the GEMM. - BaseSplitkGrouped() {} - - /// Determines whether the GEMM can execute the given problem. 
- static Status can_implement(Arguments const& args) - { - - return BaseKernel::can_implement(args); - } - - /// Get the number of tiles in a problem - static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) - { - auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); - return BaseKernel::ProblemVisitor::tile_count(grid); - } - - /// Get the number of tiles across all problems in a group - static int32_t group_tile_count(Arguments const& args) - { - if (args.host_problem_sizes == nullptr) - { - CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); - return -1; - } - - return group_tile_count(args.host_problem_sizes, args.problem_count); - } - - /// Gets the workspace size - static size_t get_workspace_size(Arguments const& args) - { - size_t total_mn = 0; - for (int i = 0; i < args.problem_count; i++) - { - total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n(); - } - size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices; - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size( - args.host_problem_sizes, args.problem_count, args.threadblock_count); - } - return workSpaceSize; - } - - /// Computes the grid shape - static dim3 get_grid_shape(Arguments const& args) - { - - return dim3(args.threadblock_count, 1, 1); - } - - /// Computes the maximum number of active blocks per multiprocessor - static int maximum_active_blocks(int smem_capacity = -1) - { - - CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()"); - - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - - CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - - cudaError_t result; - if (smem_size > (48 << 10)) - { - result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); - return -1; - } - } - - int max_active_blocks = -1; - result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, Kernel, BaseKernel::kThreadCount, smem_size); - - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); - return -1; - } - - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; - } - - /// Sorts each pointer passed in according to the indices that sort - /// `problem_sizes_ptr` in descending order of problem-K dimension. 
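
get_workspace_size() above is what makes split-K grouped GEMM memory-hungry: every problem in the group needs split_k_slices partial outputs of m x n accumulators, plus an optional problem-visitor precompute buffer. The host-side sketch below covers just the accumulator part of that formula; the two problem shapes, the fp32 accumulator type, and the split-K factor are invented for the example.

#include <cstdio>
#include <vector>

struct GemmCoordLite { int m, n, k; };  // stand-in for cutlass::gemm::GemmCoord

// Mirrors the accumulator term of the deleted BaseSplitkGrouped::get_workspace_size():
// each problem stores split_k_slices partial results of m*n accumulators.
size_t splitk_workspace_bytes(const std::vector<GemmCoordLite>& problems,
                              int split_k_slices, size_t accumulator_bytes) {
    size_t total_mn = 0;
    for (const auto& p : problems) total_mn += size_t(p.m) * size_t(p.n);
    return total_mn * accumulator_bytes * size_t(split_k_slices);
}

int main() {
    std::vector<GemmCoordLite> problems{{1024, 4096, 11008}, {512, 4096, 11008}};
    // 2 example problems, fp32 accumulators, split-K = 4 -> 100663296 bytes (96 MiB)
    std::printf("workspace = %zu bytes\n", splitk_workspace_bytes(problems, 4, sizeof(float)));
    return 0;
}
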
- static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr, - int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr, - int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr) - { - std::vector indices(problem_count); - std::iota(indices.begin(), indices.end(), 0); - std::stable_sort(indices.begin(), indices.end(), - [&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); }); - - reorder_array(problem_sizes_ptr, indices); - reorder_array(lda_host_ptr, indices); - reorder_array(ldb_host_ptr, indices); - reorder_array(ldc_host_ptr, indices); - reorder_array(ldd_host_ptr, indices); - reorder_array(offset_A_ptr, indices); - reorder_array(offset_B_ptr, indices); - reorder_array(offset_C_ptr, indices); - reorder_array(offset_D_ptr, indices); - } - - /// Computes the number of threadblocks to launch for the grouped kernel - static int sufficient( - cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1) - { - // Determine the number of blocks that would be launched to fill up a single - // wave on the GPU with each SM having maximum occupancy. - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result)); - return 0; - } - - int multiprocessor_count; - result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx); - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result)); - return 0; - } - - bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count); - if (override_sm_count) - { - available_sm_count = multiprocessor_count; - } - - int max_active_blocks = maximum_active_blocks(); - if (max_active_blocks <= 0) - { - return 0; - } - - int occupancy_based_block_count = available_sm_count * max_active_blocks; - - if (problem_sizes_ptr == nullptr || problem_count == 0) - { - return occupancy_based_block_count; - } - - int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); - - // If the group contains a single problem, launching the exact number of - // threadblocks needed to cover the problem minimizes the work performed - // per threadblock in finding the next tile to compute. We return total_tiles - // unless the user has provided the SM count. - if (problem_count == 1 && override_sm_count) - { - return total_tiles; - } - - // Choose between the full wave of threadblocks and the tile count. If there - // are fewer tiles in the group than threadblocks in the full wave, only - // some threadblocks will be assigned tiles. Those threadblocks - // which are not assigned tiles still need to perform the work of iterating through - // problem sizes to determine that they have no work to do. This competes for cycles - // with those threadblocks that are assigned tiles to compute. - return std::min(total_tiles, occupancy_based_block_count); - } - - /// Initializes GEMM state from arguments. - Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) - { - - CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace " - << workspace << ", stream: " << (stream ? 
"non-null" : "null")); - - // Workspace - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; - } - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - int32_t tile_count = group_tile_count(args); - Status status = precompute(args, tile_count, workspace); - if (status != Status::kSuccess) - { - return status; - } - - gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count); - } - else - { - gemm_params_ = typename BaseKernel::Params(args, workspace); - } - - // Specify shared memory capacity for kernel. - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - - if (smem_size >= (48 << 10)) - { - cudaError_t result - = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - return Status::kErrorInternal; - } - } - - return Status::kSuccess; - } - - /// Lightweight update given a subset of arguments - Status update(Arguments const& args, void* workspace = nullptr) - { - - size_t workspace_bytes = get_workspace_size(args); - - if (workspace_bytes && !workspace) - { - return Status::kErrorWorkspaceNull; - } - - if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) - { - int32_t tile_count = group_tile_count(args); - Status status = precompute(args, tile_count, workspace); - if (status != Status::kSuccess) - { - return status; - } - - gemm_params_.update(args, workspace, tile_count); - } - else - { - gemm_params_.update(args, workspace); - } - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) - { - if (!gemm_params_.problem_visitor.problem_count) - { - return Status::kSuccess; - } - - // - // Launch kernel - // - - // Launch splitk grouped gemm - { - dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices); - dim3 block(BaseKernel::kThreadCount, 1, 1); - - int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); - cutlass::Kernel<<>>(gemm_params_); - - cudaError_t result = cudaGetLastError(); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } - - // Launch splitkReduction - { - dim3 grid(32, gemm_params_.problem_visitor.problem_count); - dim3 block(256); - splitkReduction<<>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split, - gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices, - gemm_params_.splitk_buffer_offsets); - - cudaError_t result = cudaGetLastError(); - - if (result != cudaSuccess) - { - CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); - return Status::kErrorInternal; - } - } - - return Status::kSuccess; - } - - /// Runs the kernel using initialized state. - Status operator()(cudaStream_t stream = nullptr) - { - return run(stream); - } - - /// Initializes and runs the kernel. 
- Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr) - { - - Status status = initialize(args, workspace, stream); - - if (status == Status::kSuccess) - { - status = run(stream); - } - - return status; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// GEMM Grouped -template -class SplitkGemmGrouped : public BaseSplitkGrouped -{ -public: - using GemmKernel = GemmKernel_; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace device -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h deleted file mode 100644 index 100a1161a88..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h +++ /dev/null @@ -1,162 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "cutlass/arch/arch.h" -#include "cutlass/arch/mma.h" -#include "cutlass/bfloat16.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/half.h" -#include "cutlass/layout/matrix.h" - -#include "cutlass_extensions/arch/mma.h" -#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -template -struct MixedGemmArchTraits -{ - static_assert(dependent_false, "Unrecognised parameterization"); -}; - -template -struct MixedGemmArchTraits -{ - static constexpr int Stages = 2; - using OperatorClass = cutlass::arch::OpClassSimt; - using AccType = float; - using LayoutB = cutlass::layout::ColumnMajor; - - static constexpr int ElementsPerAccessA = 1; - static constexpr int ElementsPerAccessB = 1; - static constexpr int ElementsPerAccessC = 1; - static constexpr int ThreadblockK = 8; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; - - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -// ======================= Turing Traits ============================== -// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 -// and compute will happen in fp16 then will be converted for bf16 output. 
-template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; - - using Operator = typename LayoutDetails::Operator; -}; - -// ======================= Ampere Traits ============================== -template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; - - using Operator = typename LayoutDetails::Operator; -}; - -// ======================= Ada Traits ============================== -template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; - - using Operator = typename LayoutDetails::Operator; -}; - -// FP8 A/B = fp8, C/D = fp32 -template -struct MixedGemmArchTraits::value - || cutlass::platform::is_same::value>::type> -{ -private: - using LayoutDetails = LayoutDetailsB; - -public: - static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; - - using OperatorClass = cutlass::arch::OpClassTensorOp; - using AccType = float; - // be careful, TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t - using TypeC = __nv_bfloat16; - using LayoutB = typename LayoutDetails::Layout; - - static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; - static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; - static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; - - using Operator = typename LayoutDetails::Operator; -}; - -} // namespace kernel -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h deleted file mode 
100644 index 3fd722994e2..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h +++ /dev/null @@ -1,57 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "cutlass/arch/arch.h" -#include "cutlass/arch/mma.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/layout/matrix.h" - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -template -struct Int8GemmArchTraits -{ - using OperatorClass = cutlass::arch::OpClassSimt; - using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; -}; - -// ======================= Turing Traits ============================== -template <> -struct Int8GemmArchTraits -{ - using OperatorClass = cutlass::arch::OpClassTensorOp; - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; -}; - -// ======================= Ampere Traits ============================== -template <> -struct Int8GemmArchTraits -{ - using OperatorClass = cutlass::arch::OpClassTensorOp; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; -}; - -} // namespace kernel -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h deleted file mode 100644 index 1dbd0b1765f..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h +++ /dev/null @@ -1,207 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief - Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with - the appropriate threadblock-scoped epilogue. - - Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are - accommodated by exchanging A and B operands and assuming transposed layouts. Partial - specializations here choose 'device::GemmTransposed' to implement this functionality. - -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/complex.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/device/default_gemm_configuration.h" -#include "cutlass/gemm/kernel/default_gemm.h" -#include "cutlass/gemm/kernel/default_gemm_complex.h" -#include "cutlass/gemm/kernel/gemm_transpose_operands.h" - -#include "cutlass/layout/permute.h" - -#include "splitk_gemm_grouped.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Complex elementwise transformation on A operand - ComplexTransform TransformA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Complex elementwise transformation on B operand - ComplexTransform TransformB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Warp-level tile size (concept: GemmShape) - typename InstructionShape, - /// Epilogue output operator - typename EpilogueOutputOp, - /// Threadblock-level swizzling operator - typename ThreadblockSwizzle, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Whether the schedule of problems to visit has been precomputed - GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, - /// Operation performed by GEMM - typename Operator = typename device::DefaultGemmConfiguration::Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Permute result D - 
typename PermuteDLayout = layout::NoPermute, - /// - typename Enable = void> -struct DefaultSplitkGemmGrouped; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Real-valued GEMM kernels -// - -template < - /// Element type for A matrix operand - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC, - /// Layout type for C and D matrix operands - typename LayoutC, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Warp-level tile size (concept: GemmShape) - typename InstructionShape, - /// Epilogue output operator - typename EpilogueOutputOp, - /// Threadblock-level swizzling operator - typename ThreadblockSwizzle, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Whether the schedule of problems to visit has been precomputed - GroupScheduleMode GroupScheduleMode_, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear, - /// Permute result D - typename PermuteDLayout> -struct DefaultSplitkGemmGrouped::value>::type> -{ - - // If true, we must construct a 'transposed-and-exchanged' Mma operator. - static bool const kInternalTranspose = platform::is_same::value; - - using MapArguments = kernel::detail::MapArguments; - - // Define the default GEMM kernel - using DefaultGemmKernel = typename kernel::DefaultGemm::GemmKernel; - - /// Define the kernel in terms of the default kernel - using GemmKernel = kernel::SplitkGemmGrouped; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h deleted file mode 100644 index 0baec58ea9a..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h +++ /dev/null @@ -1,566 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. 
Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! \file - \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. -*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/arch/arch.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" - -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace detail -{ -template -inline constexpr bool dependent_false_v = false; -} - -template -struct GemmFpAIntB -{ - - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static bool const kSplitKSerial = SplitKSerial; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Element; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Mma::LayoutC; - using ElementScale = ElementC; - - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformA; - - // Type definitions about the mainloop. 
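
GemmFpAIntB above is the mixed-input kernel: A stays in a floating-point type while B is integer-quantized and dequantized on the fly using the scale (and, for some QuantOp modes, zero-point) tensors carried alongside it; in fine-grained mode one scale/zero pair covers a group_size of 64 or 128 consecutive K elements. The sketch below is only a host-side model of that groupwise dequantization; the (q - z) * s convention and the int8 values are assumptions for illustration, since the actual arithmetic lives inside the Mma iterators rather than in this header.

#include <cstdint>
#include <cstdio>
#include <vector>

// CPU model of fine-grained (groupwise) weight dequantization: one scale and
// one zero point per group_size consecutive weights along K. The deleted
// kernel only accepts group_size 64 or 128 in fine-grained mode; the
// (q - z) * s convention here is assumed for illustration.
std::vector<float> dequantize_column(const std::vector<int8_t>& q,
                                     const std::vector<float>& scales,
                                     const std::vector<float>& zeros,
                                     int group_size) {
    std::vector<float> w(q.size());
    for (size_t k = 0; k < q.size(); ++k) {
        size_t g = k / group_size;
        w[k] = (float(q[k]) - zeros[g]) * scales[g];
    }
    return w;
}

int main() {
    const int group_size = 64, k_dim = 128;  // two quantization groups along K
    std::vector<int8_t> q(k_dim, 3);
    std::vector<float> scales{0.05f, 0.10f}, zeros{1.0f, 1.0f};
    auto w = dequantize_column(q, scales, zeros, group_size);
    std::printf("w[0]=%.3f w[127]=%.3f\n", w[0], w[127]);  // 0.100 and 0.200
    return 0;
}
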
- using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - - /// Parameters structure - struct Arguments - { - GemmUniversalMode mode = GemmUniversalMode::kGemm; - - cutlass::gemm::GemmCoord problem_size; - int group_size; - typename Mma::IteratorA::TensorRef ref_A; - typename Mma::IteratorB::TensorRef ref_B; - typename Mma::IteratorScale::TensorRef ref_scale; - typename Mma::IteratorScale::TensorRef ref_zero; - typename Epilogue::OutputTileIterator::TensorRef ref_C; - typename Epilogue::OutputTileIterator::TensorRef ref_D; - - // Control serial split-k - int batch_count; - - typename EpilogueOutputOp::Params output_op; - - // For gather+scatter operations - int const* gather_A_indices; - int const* gather_B_indices; - int const* scatter_D_indices; - - // Included so we can use Gemm Universal - int batch_stride_D = 0; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Arguments() {} - - CUTLASS_HOST_DEVICE - Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, - typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, - typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, - typename Epilogue::OutputTileIterator::TensorRef ref_C, - typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, - typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), - int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr, - int const* scatter_D_indices = nullptr) - : problem_size(problem_size) - , group_size(group_size) - , ref_A(ref_A) - , ref_B(ref_B) - , ref_scale(ref_scale) - , ref_zero(ref_zero) - , ref_C(ref_C) - , ref_D(ref_D) - , batch_count(serial_split_k_factor) - , output_op(output_op) - , gather_A_indices(gather_A_indices) - , gather_B_indices(gather_B_indices) - , scatter_D_indices(scatter_D_indices) - { - } - }; - - /// Parameters structure - struct Params - { - cutlass::gemm::GemmCoord problem_size; - int group_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorA::TensorRef ref_A; - typename Mma::IteratorB::Params params_B; - typename Mma::IteratorB::TensorRef ref_B; - typename Mma::IteratorScale::Params params_scale; - typename Mma::IteratorScale::TensorRef ref_scale; - typename Mma::IteratorScale::TensorRef ref_zero; - typename Epilogue::OutputTileIterator::Params params_C; - typename Epilogue::OutputTileIterator::TensorRef ref_C; - typename Epilogue::OutputTileIterator::Params params_D; - typename Epilogue::OutputTileIterator::TensorRef ref_D; - typename EpilogueOutputOp::Params output_op; - int* semaphore; - int gemm_k_size; - // For gather+scatter operations - int const* 
gather_A_indices; - int const* gather_B_indices; - int const* scatter_D_indices; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() - : swizzle_log_tile(0) - , semaphore(0) - , gemm_k_size(0) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, - void* workspace = nullptr) - : problem_size(args.problem_size) - , group_size(args.group_size) - , grid_tiled_shape(grid_tiled_shape) - , swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)) - , params_A(args.ref_A.layout()) - , ref_A(args.ref_A) - , params_B(args.ref_B.layout()) - , ref_B(args.ref_B) - , params_scale(args.ref_scale.layout()) - , ref_scale(args.ref_scale) - , ref_zero(args.ref_zero) - , params_C(args.ref_C.layout()) - , ref_C(args.ref_C) - , params_D(args.ref_D.layout()) - , ref_D(args.ref_D) - , output_op(args.output_op) - , semaphore(static_cast(workspace)) - , gemm_k_size(gemm_k_size) - , gather_A_indices(args.gather_A_indices) - , gather_B_indices(args.gather_B_indices) - , scatter_D_indices(args.scatter_D_indices) - { - } - }; - - /// Shared memory storage structure - union SharedStorage - { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - GemmFpAIntB() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(Arguments const& args) - { - static int const kAlignmentA - = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) - ? 64 - : Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB - = (platform::is_same>::value) ? 32 - : (platform::is_same>::value) - ? 64 - : Mma::IteratorB::AccessType::kElements; - - static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements; - - static int const kAlignmentC = (platform::is_same>::value) - ? 32 - : (platform::is_same>::value) - ? 64 - : Epilogue::OutputTileIterator::kElementsPerAccess; - - if (!TensorRef_aligned(args.ref_A, kAlignmentA)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_B, kAlignmentB)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_zero, kAlignmentScale)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_C, kAlignmentC)) - { - return Status::kErrorMisalignedOperand; - } - - if (!TensorRef_aligned(args.ref_D, kAlignmentC)) - { - return Status::kErrorMisalignedOperand; - } - - if (!args.ref_scale.good()) - { - return Status::kErrorNotSupported; - } - - if constexpr (hasZero(Mma::QuantOp)) - { - if (!args.ref_zero.good()) - { - return Status::kErrorNotSupported; - } - } - else - { - if (args.ref_zero.good()) - { - return Status::kErrorNotSupported; - } - } - - if constexpr (isFinegrained(Mma::QuantOp)) - { - if (args.group_size != 64 && args.group_size != 128) - { - return Status::kErrorNotSupported; - } - } - - return Status::kSuccess; - } - - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - - return 0; - } - - // Initializes the fine grained scale+bias iterator. 
Needed since the fine grained iterator - // has a different constructor signature than a regular cutlass iterator - template = true> - CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, - typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, - typename IteratorScale::TensorCoord extent, int thread_id, - typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) - { - - return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); - } - - template = true> - CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, - typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, - typename IteratorScale::TensorCoord extent, int thread_id, - typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) - { - - return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); - } - - CUTLASS_DEVICE - void run_kernel_(Params const& params, SharedStorage& shared_storage) - { - using LayoutB = typename Mma::IteratorB::Layout; - static_assert(platform::is_same::value && kInterleave == 1 - || platform::is_same::value && kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() - || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) - { - - return; - } - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - threadblock_tile_offset.k() * params.gemm_k_size, - }; - - cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, - threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; - - typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; - typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; - cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; - - // Problem size is a function of threadblock index in the K dimension - int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), - {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices); - - typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), - {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B, - params.gather_B_indices); - - typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? 
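> The two `initialize_scale()` overloads above exist because the fine-grained (group-wise) scale iterator takes a zero-point pointer and a group size while the per-channel iterator does not. Here is a sketch of the same dispatch idea, written with `if constexpr` instead of the enable_if-style overload selection used in the removed code; the iterator types are placeholders.

```cpp
// Sketch of scale-iterator construction dispatch: group-wise quantization needs
// (scale, zero, group_size); per-channel quantization needs only the scales.
struct PerChannelScaleIter {
    PerChannelScaleIter(const float* scale, int n) : scale_(scale), n_(n) {}
    const float* scale_; int n_;
};

struct FineGrainedScaleIter {
    FineGrainedScaleIter(const float* scale, const float* zero, int n, int group_size)
        : scale_(scale), zero_(zero), n_(n), group_size_(group_size) {}
    const float* scale_; const float* zero_; int n_; int group_size_;
};

template <bool kFineGrained>
auto make_scale_iterator(const float* scale, const float* zero, int n, int group_size) {
    if constexpr (kFineGrained) {
        return FineGrainedScaleIter(scale, zero, n, group_size);
    } else {
        (void)zero; (void)group_size;      // per-channel path ignores these
        return PerChannelScaleIter(scale, n);
    }
}

int main() {
    static float scale[64], zero[64];
    auto fine = make_scale_iterator<true>(scale, zero, 64, 128);
    auto coarse = make_scale_iterator<false>(scale, nullptr, 64, 0);
    (void)fine; (void)coarse;
}
```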
problem_size_k / 64 : 1; - typename Mma::IteratorScale iterator_scale = initialize_scale( - params.params_scale, params.ref_scale.data(), params.ref_zero.data(), - {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - if (!kSplitKSerial || gemm_k_iterations > 0) - { - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); - } - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); - - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - // Construct the semaphore. - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - - // If performing a reduction via split-K, fetch the initial synchronization - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); - - // Indicate which position in a serial reduction the output operator is currently updating - output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(), - thread_idx, threadblock_offset, params.scatter_D_indices); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(), - thread_idx, threadblock_offset, params.scatter_D_indices); - - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); - - // Wait on the semaphore - this latency may have been covered by iterator construction - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. - if (threadblock_tile_offset.k()) - { - iterator_C = iterator_D; - } - - semaphore.wait(threadblock_tile_offset.k()); - } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); - - // - // Release the semaphore - // - - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) - { - - int lock = 0; - if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) - { - - // The final threadblock resets the semaphore for subsequent grids. 
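> The split-K section above serializes partial-sum accumulation with a per-tile semaphore: partition k waits for the counter to reach k, adds its partial tile into D, then publishes the next value (or resets it for the final partition). The following host-only illustration reproduces that ordering protocol with `std::atomic` and threads; it is an analogy, not the CUDA `Semaphore` class used in the kernel.

```cpp
// Host-side analogy of the serial split-K protocol: an atomic counter plays the
// role of the per-tile semaphore, a float plays the role of the D output tile.
#include <atomic>
#include <thread>
#include <vector>
#include <cstdio>

int main() {
    constexpr int kSplits = 4;
    std::atomic<int> semaphore{0};
    float d_tile = 0.0f;                       // stands in for the D output tile

    std::vector<std::thread> partitions;
    for (int k = 0; k < kSplits; ++k) {
        partitions.emplace_back([&, k] {
            float partial = 1.0f * (k + 1);    // pretend partial accumulator
            while (semaphore.load(std::memory_order_acquire) != k) { /* spin */ }
            d_tile += partial;                 // k == 0 writes, k > 0 accumulates
            const int next = (k + 1 == kSplits) ? 0 : k + 1;  // last split resets the lock
            semaphore.store(next, std::memory_order_release);
        });
    }
    for (auto& t : partitions) t.join();
    std::printf("tile = %f\n", d_tile);        // 1 + 2 + 3 + 4 = 10
}
```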
- lock = 0; - } - else - { - // Otherwise, the semaphore is incremented - lock = threadblock_tile_offset.k() + 1; - } - - semaphore.release(lock); - } - } - - template - CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) - { - if constexpr (platform::is_same::value) - { - run_kernel_(params, shared_storage); - } - else - { - CUTLASS_NOT_IMPLEMENTED(); - } - } - - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. - */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { -#if defined(__CUDA_ARCH__) -#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ == 890) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 900) - CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. -#else - static_assert( - false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); -#endif -#else - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh deleted file mode 100644 index 1bd0a3f11a8..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh +++ /dev/null @@ -1,218 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include -#include -#include - -namespace fused_moe -{ -template -struct Fused_Moe_Kernel_sm80 -{ - static constexpr int kMaxTileM = MaxTileM_; - static constexpr int kTileN = isGateActivation(activation_type_) ? 
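> The `operator()` at the end of the removed file compiles a real body only when `__CUDA_ARCH__` matches the kernel's `ArchTag`, keeping the other architecture passes cheap. A compact CUDA sketch of that dispatch pattern follows (builds with nvcc and -std=c++17); the tag names and the empty bodies are placeholders, not the CUTLASS types.

```cpp
// Sketch of compile-time architecture dispatch: one entry point per ArchTag,
// with the device pass selecting the matching body via __CUDA_ARCH__.
#include <cuda_runtime.h>
#include <type_traits>

struct Sm75 {}; struct Sm80 {}; struct Sm89 {};

template <typename KernelArch>
__global__ void gemm_entry()
{
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
    if constexpr (std::is_same_v<KernelArch, Sm75>) { /* Turing kernel body */ }
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890)
    if constexpr (std::is_same_v<KernelArch, Sm80>) { /* Ampere kernel body */ }
#elif (__CUDA_ARCH__ == 890)
    if constexpr (std::is_same_v<KernelArch, Sm89>) { /* Ada kernel body */ }
#endif
    // Architectures with no matching branch compile to an empty kernel,
    // analogous to CUTLASS_NOT_IMPLEMENTED() in the original.
#endif
}

int main()
{
    gemm_entry<Sm80><<<1, 1>>>();   // the host side instantiates one ArchTag per build target
    return cudaDeviceSynchronize() == cudaSuccess ? 0 : 1;
}
```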
TileN_ / 2 : TileN_; - static constexpr int kTileK = TileK_; - static constexpr int kStages = Stages_; - static constexpr Activation_Type activation_type = activation_type_; - - using ElementInput = ElementInput_; - using ElementWeight = ElementWeight_; - using ElementOutput = ElementOutput_; - using BaseKernelTraits = Fused_Moe_Kernel_traits_sm80; - using Routine_Arguments = Routine_Arguments; - using Routine_Params = Routine_Params; - using ProblemVisitor - = cutlass::gemm::kernel::MoeProblemVisitor, false>, - cutlass::gemm::GemmShape, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, - BaseKernelTraits::kThreadCount, BaseKernelTraits::kThreadCount>; - - struct Arguments - { - Routine_Arguments routine_args; - int problem_count{}; - int threadblock_count{}; - }; - - struct Params - { - Routine_Params routine_params; - int threadblock_count{}; - typename ProblemVisitor::Params problem_visitor_param; - }; - - using BaseKernelTraits_m16 = Fused_Moe_Kernel_traits_sm80; - static constexpr bool use_m16 = TileK_ >= 64; // use tileshape m = 16 when original tileshape k >= 64 - - static constexpr int kSmemSize = use_m16 - ? (BaseKernelTraits::kSmemSize > BaseKernelTraits_m16::kSmemSize ? BaseKernelTraits::kSmemSize - : BaseKernelTraits_m16::kSmemSize) - : BaseKernelTraits::kSmemSize; - static constexpr int kThreadCount = BaseKernelTraits::kThreadCount; - - static constexpr bool can_implement(int const avaliable_smem_size) - { - return BaseKernelTraits::can_implement(avaliable_smem_size); - } - - static Params to_underlying_arguments(Arguments const& args) - { - return { - {args.routine_args.ptr_input, args.routine_args.ptr_fc1, args.routine_args.ptr_bias, - args.routine_args.ptr_output, args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, - args.routine_args.gemm_k, args.routine_args.num_expert, args.routine_args.bias_is_broadcast}, - args.threadblock_count, - {args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, args.routine_args.gemm_k, - args.problem_count, nullptr, 0}}; - } - - CUTE_DEVICE - void run_device(Params const& params) - { -#define ROUTINE_PATH(kTileM_size) \ - { \ - constexpr int kTileM = use_m16 ? (kTileM_size) : ((kTileM_size) == 16 ? 
32 : (kTileM_size)); \ - using RoutineTraits = Fused_Moe_Kernel_routine_sm80; \ - RoutineTraits routine{}; \ - int const block_m_idx = (block_m_idx_temp) *kMaxTileM / kTileM; \ - routine.run_routine(params.routine_params, problem_index, block_m_idx, block_n_idx, gemm_m); \ - } - typename ProblemVisitor::SharedStorage dummy_storage{}; - ProblemVisitor problem_visitor(params.problem_visitor_param, dummy_storage, blockIdx.x); - while (problem_visitor.next_tile()) - { - auto problem_size = problem_visitor.problem_size(); - auto grid_size = problem_visitor.grid_shape(problem_size); - auto problem_index = problem_visitor.problem_index(); - int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); - int const gemm_m = problem_size.m(); - const int32_t block_m_idx_temp = cta_idx / grid_size.n(); - const int32_t block_n_idx = cta_idx % grid_size.n(); - - int const residue_m = gemm_m - kMaxTileM * block_m_idx_temp; - if (residue_m > kMaxTileM / 2) - { - using RoutineTraits = Fused_Moe_Kernel_routine_sm80; - RoutineTraits routine{}; - routine.run_routine(params.routine_params, problem_index, block_m_idx_temp, block_n_idx, gemm_m); - } - else - { - - if constexpr (kMaxTileM >= 128) - { - if (residue_m > 32) - { - ROUTINE_PATH(64); - } - else if (residue_m > 16) - { - ROUTINE_PATH(32); - } - else - { - // TODO: use cuda core gemm here - ROUTINE_PATH(16); - } - } - else if (kMaxTileM == 64) - { - if (residue_m > 16) - { - ROUTINE_PATH(32); - } - else - { - // TODO: use cuda core gemm here - ROUTINE_PATH(16); - } - } - else if (kMaxTileM == 32) - { - // TODO: use cuda core gemm here - ROUTINE_PATH(16); - } - else - { - // TODO: use cuda core gemm here - ROUTINE_PATH(16); - } - } - problem_visitor.advance(gridDim.x); - } -#undef ROUTINE_PATH - } -}; - -template -__global__ void run_global(__grid_constant__ typename GemmType::Params const params) -{ - GemmType gemm; - gemm.run_device(params); -} - -/// Computes the maximum number of active blocks per multiprocessor -template -static int fused_gemm_maximum_active_blocks(int smem_capacity = -1) -{ - - CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()"); - - constexpr int smem_size = GemmType::kSmemSize; - - CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); - - cudaError_t result; - if (smem_size > (48 << 10)) - { - result = cudaFuncSetAttribute(run_global, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); - return -1; - } - } - - int max_active_blocks = -1; - result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, run_global, GemmType::kThreadCount, smem_size); - - if (result != cudaSuccess) - { - // Call cudaGetLastError() to clear the error bit - result = cudaGetLastError(); - CUTLASS_TRACE_HOST( - " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); - return -1; - } - - CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); - return max_active_blocks; -} -} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh deleted file mode 100644 index 4c46a541efd..00000000000 --- 
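> `run_device()` above picks a smaller TileM routine when the last row block of an expert covers only a few residual rows. The branch thresholds can be read directly from the removed code; the sketch below reproduces just that selection logic as plain host C++, omitting the routine types and the actual launch.

```cpp
// Sketch of the residue-based TileM selection in run_device(): a ragged final
// row block falls back to a smaller tile to avoid wasted work.
#include <cstdio>

int pick_tile_m(int residue_m, int k_max_tile_m) {
    if (residue_m > k_max_tile_m / 2) return k_max_tile_m;   // keep the full-size tile
    if (k_max_tile_m >= 128) {
        if (residue_m > 32) return 64;
        if (residue_m > 16) return 32;
        return 16;
    }
    if (k_max_tile_m == 64) return (residue_m > 16) ? 32 : 16;
    return 16;                                               // kMaxTileM <= 32
}

int main() {
    // e.g. an expert with 70 or 20 remaining rows under a 128-row max tile:
    std::printf("tile_m = %d\n", pick_tile_m(70, 128));      // -> 128 (70 > 64)
    std::printf("tile_m = %d\n", pick_tile_m(20, 128));      // -> 32
}
```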
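> `fused_gemm_maximum_active_blocks()` above is a thin wrapper over the CUDA occupancy API, first opting the kernel into more than 48 KiB of dynamic shared memory. Below is a host-side sketch of the same sequence with a placeholder kernel and an assumed shared-memory size.

```cpp
// Host-side sketch of the occupancy query: raise the dynamic shared-memory
// limit if needed, then ask how many blocks of this kernel fit per SM.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void my_kernel() {}                 // stand-in for run_global<GemmType>

int maximum_active_blocks(int threads_per_block, int smem_bytes) {
    if (smem_bytes > (48 << 10)) {
        // Required once per kernel before launching with large dynamic smem.
        if (cudaFuncSetAttribute(my_kernel,
                                 cudaFuncAttributeMaxDynamicSharedMemorySize,
                                 smem_bytes) != cudaSuccess) {
            cudaGetLastError();                // clear the sticky error bit
            return -1;
        }
    }
    int max_blocks = -1;
    if (cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &max_blocks, my_kernel, threads_per_block, smem_bytes) != cudaSuccess) {
        cudaGetLastError();
        return -1;
    }
    return max_blocks;
}

int main() {
    std::printf("active blocks/SM: %d\n", maximum_active_blocks(128, 64 << 10));
}
```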
a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh +++ /dev/null @@ -1,799 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include - -namespace fused_moe -{ - -template -struct Fused_Moe_Kernel_routine_sm80; - -template -struct Fused_Moe_Kernel_routine_sm80> -{ - using KT = Fused_Moe_Kernel_traits_sm80; - using Params = Routine_Params; - - CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params) - { - using X = cute::Underscore; - - int const M = gemm_m; - int const N1 = params.gemm_n; - int const K1 = params.gemm_k; - bool const bias_is_broadcast = params.bias_is_broadcast; - - int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]); - typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1; - typename KT::ElementWeight const* ptr_fc1_gate_ - = params.ptr_fc1 + (2 * problem_index + 1) * N1 * K1; // TODO: we only focus on gated activation.. - typename KT::ElementWeight const* ptr_fc1_ - = params.ptr_fc1 + 2 * problem_index * N1 * K1; // TODO: we only focus on gated activation.. - typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr) - ? nullptr - : (bias_is_broadcast ? params.ptr_bias + 2 * problem_index * N1 : params.ptr_bias + 2 * row_jump * N1); - typename KT::ElementInput const* ptr_bias_gate_ = (params.ptr_bias == nullptr) - ? nullptr - : (bias_is_broadcast ? params.ptr_bias + (2 * problem_index + 1) * N1 - : params.ptr_bias + (2 * row_jump + 1) * N1); - typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1; - - cute::Tensor mInput_mk - = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_input_)), - cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{})); - - cute::Tensor mfc1_gate_nk - = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_gate_)), - cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); - - cute::Tensor mfc1_nk - = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_)), - cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); - - cute::Tensor mBias_mn = cute::make_tensor( - cute::make_gmem_ptr(static_cast(ptr_bias_)), cute::make_shape(M, N1), - cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2, - cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. - - cute::Tensor mBias_gate_mn = cute::make_tensor( - cute::make_gmem_ptr(static_cast(ptr_bias_gate_)), cute::make_shape(M, N1), - cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2, - cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. 
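> `gmem_tensor_init()` above locates each expert's slice of the token-sorted activations via an inclusive prefix sum of tokens per expert, and for gated activations reads the fc1 weights as two consecutive [N, K] slabs per expert (value at 2*e, gate at 2*e + 1). A small sketch of that pointer bookkeeping; the struct name and buffer contents are illustrative.

```cpp
// Sketch of the expert offset arithmetic in gmem_tensor_init().
#include <cstdint>
#include <cstdio>

struct ExpertView {
    std::int64_t input_row_offset;   // first row of this expert in the token-sorted input
    std::int64_t fc1_value_offset;   // element offset of the non-gate weight slab
    std::int64_t fc1_gate_offset;    // element offset of the gate weight slab
};

ExpertView expert_view(const std::int64_t* total_tokens_including_expert,
                       int expert, std::int64_t gemm_n, std::int64_t gemm_k) {
    const std::int64_t row_jump = (expert == 0) ? 0 : total_tokens_including_expert[expert - 1];
    return {row_jump,
            (2 * expert) * gemm_n * gemm_k,        // value slab of expert e
            (2 * expert + 1) * gemm_n * gemm_k};   // gate slab of expert e
}

int main() {
    const std::int64_t prefix[] = {5, 12, 20};     // inclusive token counts for experts 0..2
    const auto v = expert_view(prefix, 1, /*gemm_n=*/4096, /*gemm_k=*/1024);
    std::printf("rows start at %lld, gate slab at element %lld\n",
                (long long)v.input_row_offset, (long long)v.fc1_gate_offset);
}
```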
- - cute::Tensor mOutput_mn - = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_output_)), - cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{})); - - cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_K, m, k) - cute::Tensor gfc1_gate_nk = cute::local_tile(mfc1_gate_nk, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) - cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) - - cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) - - cute::Tensor gBias_gate_mn = cute::local_tile(mBias_gate_mn, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) - - cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) - - return cute::make_tuple(gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn); - } - - // be careful, m_idx will change when use another tile shape.. - CUTE_DEVICE void run_routine( - Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m) - { - extern __shared__ char smem_[]; - typename KT::SharedStorage& shared_storage = *reinterpret_cast(smem_); - int const thread_idx = threadIdx.x; - bool const bias_is_broadcast = params.bias_is_broadcast; - // gmem tensor partition .. - auto [gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn] - = gmem_tensor_init(problem_index, gemm_m, params); - int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk); - auto const n_tile_count = cute::size<2>(gfc1_gate_nk); - - // smem tensor .. 
- cute::Tensor sInput = cute::make_tensor( - cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage) - cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()), - typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) - cute::Tensor sfc1_gate_weight - = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_gate_weight.data()), - typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) - cute::Tensor sO = cute::make_tensor( - cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N) - - // (1) first step, get the fc1_res and fc1_gate - - // (1.1) get partition for gmem -> smem - cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k) - cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) - cute::Tensor gfc1g = gfc1_gate_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) - - typename KT::GmemTiledCopyA gmem_tiled_copy_A; - typename KT::GmemTiledCopyB gmem_tiled_copy_B; - auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); - auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); - - cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k) - cute::Tensor tInputsInput = gmem_thr_copy_A.partition_D(sInput); // (ACPY,ACPY_M,ACPY_K,Stage) - cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k) - cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage) - cute::Tensor tfc1ggfc1g = gmem_thr_copy_B.partition_S(gfc1g); // (BCPY,BCPY_N,BCPY_K,k) - cute::Tensor tfc1gsfc1g = gmem_thr_copy_B.partition_D(sfc1_gate_weight); // (BCPY,BCPY_N,BCPY_K,Stage) - - // Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor) - cute::Tensor tInputpInput - = cute::make_tensor(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)), - cute::Stride{}); - // Construct identity layout for sInput - cute::Tensor cInput = make_identity_tensor( - make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - - // Repeat the partitioning with identity layouts - cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - - // Set predicates for m bounds - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < cute::size<0>(tInputpInput); ++m) - { - tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m - } - - // (1.2) prefetch gmem -> smem - cute::clear(tInputsInput); // we don't need to clear tfc1sfc1.. 
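> The predicate tensor built above guards the global-to-shared copies of the activation tile against the ragged M boundary (`residue_m`); rows past the boundary keep the zeros written by the preceding clear. A scalar sketch of that row mask, with an assumed per-partition row count:

```cpp
// Sketch of the row-bounds predication: only rows inside residue_m are copied.
#include <array>
#include <cstdio>

int main() {
    constexpr int kBlockM = 8;                 // rows owned by one copy partition (assumed)
    const int residue_m = 5;                   // rows actually left in this tile

    std::array<bool, kBlockM> predicate{};
    for (int m = 0; m < kBlockM; ++m)
        predicate[m] = m < residue_m;          // mirrors: blk_m coord < residue_m

    for (int m = 0; m < kBlockM; ++m)
        std::printf("row %d: %s\n", m, predicate[m] ? "copy" : "skip (stays zero)");
}
```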
- auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // emm, iter start from 0 - int k_tile_count = cute::size<2>(gInput); - CUTLASS_PRAGMA_UNROLL - for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe) - { - if (k_tile_count <= 0) - { - cute::clear(tInputpInput); - } - // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - // tInputsInput(cute::_, cute::_, cute::_, k_pipe)); - // use copy_if - cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - tInputsInput(cute::_, cute::_, cute::_, k_pipe)); - cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), - tfc1sfc1(cute::_, cute::_, cute::_, k_pipe)); - cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter), - tfc1gsfc1g(cute::_, cute::_, cute::_, k_pipe)); - cute::cp_async_fence(); - k_tile_count--; - if (k_tile_count > 0) - { - ++k_tile_iter; - } - } - - // (1.3) get partition for rf - typename KT::TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(thread_idx); - cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K) - cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) - cute::Tensor tOrfc1g = thr_mma.partition_fragment_B(sfc1_gate_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) - - cute::Tensor accum - = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) - cute::Tensor accum_gate - = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) - cute::clear(accum); - cute::clear(accum_gate); - // checkout the shape - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum_gate)); // MMA_M - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum_gate)); // MMA_N - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum)); // MMA_N - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum_gate)); // MMA_N - CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K - CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1g)); // MMA_K - CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma)); - CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma)); - - // (1.4)retiling the smem and rf for copy.. 
- auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma); - auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); - cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage) - cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K) - CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M - CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K - - auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma); - auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); - cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage) - cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K) - cute::Tensor tOsfc1g = smem_thr_copy_B.partition_S(sfc1_gate_weight); // (CPY,CPY_N,CPY_K,Stage) - cute::Tensor tOrfc1g_copy_view = smem_thr_copy_B.retile_D(tOrfc1g); // (CPY,CPY_N,CPY_K) - CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N - CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K - CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1g) == cute::size<1>(tOrfc1g_copy_view)); // CPY_N - CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1g) == cute::size<2>(tOrfc1g_copy_view)); // CPY_K - - // (1.5) mainloop - // Current pipe index in smem to read from - int smem_pipe_read = 0; - // Current pipe index in smem to write to - int smem_pipe_write = KT::Stages - 1; - - cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); - cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); - cute::Tensor tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); - - constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput); - // prefetch register pipeline - if constexpr (K_BLOCK_MAX > 1) - { - cute::cp_async_wait(); - __syncthreads(); - - // Prefetch the first rmem from the first k-tile - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}), - tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{})); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}), - tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{})); - cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, cute::Int<0>{}), - tOrfc1g_copy_view(cute::_, cute::_, cute::Int<0>{})); - } - // k loop for mainloop - CUTLASS_PRAGMA_NO_UNROLL - for (; k_tile_count > 0; --k_tile_count) - { - cute::for_each(cute::make_int_sequence{}, - [&](auto k_block) - { - if (k_block == K_BLOCK_MAX - 1) - { - tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); - tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); - tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); - cute::cp_async_wait(); - __syncthreads(); - } - // Load A, B shmem->regs for k_block+1 - auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), - tOrInput_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), - tOrfc1_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), - tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); - // Copy gmem to smem before computing gemm on each k-pipe - if (k_block == 0) - { - // cute::copy(gmem_tiled_copy_A, 
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - // tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); - cute::copy_if(gmem_tiled_copy_A, tInputpInput, - tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); - cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), - tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write)); - cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter), - tfc1gsfc1g(cute::_, cute::_, cute::_, smem_pipe_write)); - cute::cp_async_fence(); - if (k_tile_count - 1 > 0) - { - ++k_tile_iter; - } - - // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) - smem_pipe_write = smem_pipe_read; - ++smem_pipe_read; - smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read; - } - // Thread-level register gemm for k_block - cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), - accum); - cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), - tOrfc1g(cute::_, cute::_, k_block), accum_gate); - }); - } - - // load tail - cute::for_each(cute::make_int_sequence{}, - [&](auto WaitIndex) - { - k_tile_count--; - using WaitIndex_t = decltype(WaitIndex); - cute::for_each(cute::make_int_sequence{}, - [&](auto k_block) - { - if (k_block == K_BLOCK_MAX - 1) - { - tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); - tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); - tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); - cute::cp_async_wait(); - __syncthreads(); - } - // Load A, B shmem->regs for k_block+1 - auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), - tOrInput_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), - tOrfc1_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), - tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); - if (k_block == 0) - { - // only update smem_pipe_read - ++smem_pipe_read; - smem_pipe_read = (smem_pipe_read == KT::Stages) ? 
0 : smem_pipe_read; - } - // Thread-level register gemm for k_block - cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), - tOrfc1(cute::_, cute::_, k_block), accum); - cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), - tOrfc1g(cute::_, cute::_, k_block), accum_gate); - }); - }); - // mma tail - cute::for_each(cute::make_int_sequence{}, - [&](auto k_block) - { - // Load A, B shmem->regs for k_block+1 - auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), - tOrInput_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), - tOrfc1_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), - tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); - // Thread-level register gemm for k_block - cute::gemm( - tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum); - cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), - tOrfc1g(cute::_, cute::_, k_block), accum_gate); - }); - // if (cute::thread0()) { - // cute::print(accum_gate(0, 0, 0)); - // printf("\n"); - // } - // (2) add bias if it has.. - if (params.ptr_bias != nullptr) - { - cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); - cute::Tensor gBias_gate = gBias_gate_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); - cute::Tensor tOgBias = thr_mma.partition_C(gBias); - cute::Tensor tOgBiasg = thr_mma.partition_C(gBias_gate); - for (int i = 0; i < cute::size(accum); i++) - { - accum(i) += tOgBias(i); - accum_gate(i) += tOgBiasg(i); - } - } - - // (3) calculate swiglu - using ActivationFn = typename KT::ActivationFn; - ActivationFn fn{}; - CUTLASS_PRAGMA_UNROLL - for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++) - { - accum(temp_iter) = fn(accum_gate(temp_iter)) * accum(temp_iter); - } - - // (4) push all the result to smem - // (4.1) convert result from ElementAccum to ElementInput - cute::Tensor temp_accum = util_convert_type(accum); - // if (cute::thread0()) { - // cute::print(temp_accum(0, 0, 0)); - // printf("\n"); - // } - // (4.2) retile rf and smem for copy back.. - auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); - // cute::clear(sO); - cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum); - cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO); - - // (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..) - cute::copy(smem_tiled_copy_O, taccumrO, taccumsO); - __syncthreads(); - - // (4.4) sO -> rO -> gO - - typename KT::GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - // auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); // - // remember, for all the threads in the same col, they have the same idx for bias.. - cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx); - // cute::Tensor gBias = gBias_mn(cute::_, cute::_, 0, block_n_idx); // bias only have one row.. 
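> Steps (2)–(3) above add the optional bias to both accumulators and then combine them as act(gate) * value, which is SwiGLU when the activation is SiLU. A scalar sketch of that epilogue arithmetic; the values are arbitrary.

```cpp
// Sketch of the gated epilogue: out = silu(gate + bias_gate) * (value + bias).
#include <cmath>
#include <cstdio>

float silu(float x) { return x / (1.0f + std::exp(-x)); }

int main() {
    float accum      = 0.7f;   // value-path accumulator element
    float accum_gate = -1.3f;  // gate-path accumulator element
    float bias = 0.1f, bias_gate = -0.2f;

    accum      += bias;        // (2) add bias if present
    accum_gate += bias_gate;
    const float out = silu(accum_gate) * accum;   // (3) SwiGLU
    std::printf("swiglu output element = %f\n", out);
}
```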
- auto tOsO = gmem_thr_copy_O.partition_S(sO); - auto tOgO = gmem_thr_copy_O.partition_D(gO); - // auto tOgBias = gmem_thr_copy_O.partition_D(gBias); - cute::Tensor cOutput = cute::make_identity_tensor( - cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{}))); - cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput); - cute::Tensor tOrO = cute::make_tensor(cute::shape(tOgO)); - cute::copy(gmem_tiled_copy_O, tOsO, tOrO); - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < cute::size<1>(tOgO); ++m) - { - if (cute::get<0>(tOcO(0, m, 0)) < residue_m) - { - cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_)); - } - } - } -}; - -template -struct Fused_Moe_Kernel_routine_sm80> -{ - - using KT = Fused_Moe_Kernel_traits_sm80; - using Params = Routine_Params; - - CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params) - { - using X = cute::Underscore; - - int const M = gemm_m; - int const N1 = params.gemm_n; - int const K1 = params.gemm_k; - bool const bias_is_broadcast = params.bias_is_broadcast; - - int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]); - typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1; - typename KT::ElementWeight const* ptr_fc1_ = params.ptr_fc1 + problem_index * N1 * K1; - typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr) - ? nullptr - : (bias_is_broadcast ? params.ptr_bias + problem_index * N1 : params.ptr_bias + row_jump * N1); - typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1; - - cute::Tensor mInput_mk - = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_input_)), - cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{})); - - cute::Tensor mfc1_nk - = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_)), - cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); - - cute::Tensor mBias_mn = cute::make_tensor( - cute::make_gmem_ptr(static_cast(ptr_bias_)), cute::make_shape(M, N1), - cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1, - cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. - - cute::Tensor mOutput_mn - = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_output_)), - cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{})); - - cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_K, m, k) - cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) - - cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) - - cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{}, - cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) - - return cute::make_tuple(gInput_mk, gfc1_nk, gBias_mn, gOutput_mn); - } - - // be careful, m_idx will change when use another tile shape.. 
- CUTE_DEVICE void run_routine( - Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m) - { - extern __shared__ char smem_[]; - typename KT::SharedStorage& shared_storage = *reinterpret_cast(smem_); - int const thread_idx = threadIdx.x; - bool const bias_is_broadcast = params.bias_is_broadcast; - // gmem tensor partition .. - auto [gInput_mk, gfc1_nk, gBias_mn, gOutput_mn] = gmem_tensor_init(problem_index, gemm_m, params); - int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk); - auto const n_tile_count = cute::size<2>(gfc1_nk); - - // smem tensor .. - cute::Tensor sInput = cute::make_tensor( - cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage) - cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()), - typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) - cute::Tensor sO = cute::make_tensor( - cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N) - - // (1) first step, get the fc1_res and fc1_gate - - // (1.1) get partition for gmem -> smem - cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k) - cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) - - typename KT::GmemTiledCopyA gmem_tiled_copy_A; - typename KT::GmemTiledCopyB gmem_tiled_copy_B; - auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); - auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); - - cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k) - cute::Tensor tInputsInput = gmem_thr_copy_A.partition_S(sInput); // (ACPY,ACPY_M,ACPY_K,Stage) - cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k) - cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage) - - // Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor) - cute::Tensor tInputpInput - = cute::make_tensor(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)), - cute::Stride{}); - // Construct identity layout for sInput - cute::Tensor cInput = make_identity_tensor( - make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - - // Repeat the partitioning with identity layouts - cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - - // Set predicates for m bounds - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < cute::size<0>(tInputpInput); ++m) - { - tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m - } - - // (1.2) prefetch gmem -> smem - cute::clear(tInputsInput); // we don't need to clear tfc1sfc1.. 
- auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // emm, iter start from 0 - int k_tile_count = cute::size<2>(gInput); - CUTLASS_PRAGMA_UNROLL - for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe) - { - if (k_tile_count <= 0) - { - cute::clear(tInputpInput); - } - // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - // tInputsInput(cute::_, cute::_, cute::_, k_pipe)); - // use copy_if - cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - tInputsInput(cute::_, cute::_, cute::_, k_pipe)); - cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), - tfc1sfc1(cute::_, cute::_, cute::_, k_pipe)); - cute::cp_async_fence(); - k_tile_count--; - if (k_tile_count > 0) - { - ++k_tile_iter; - } - } - - // (1.3) get partition for rf - typename KT::TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(thread_idx); - cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0)); // (MMA,MMA_M,MMA_K) - cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K) - - cute::Tensor accum - = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N) - cute::clear(accum); - // checkout the shape - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum)); // MMA_M - CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum)); // MMA_N - CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K - CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma)); - CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma)); - - // (1.4)retiling the smem and rf for copy.. 
- auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma); - auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); - cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage) - cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K) - CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M - CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K - - auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma); - auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); - cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage) - cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K) - CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N - CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K - - // (1.5) mainloop - // Current pipe index in smem to read from - int smem_pipe_read = 0; - // Current pipe index in smem to write to - int smem_pipe_write = KT::Stages - 1; - - cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); - cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); - - constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput); - // prefetch register pipeline - if constexpr (K_BLOCK_MAX > 1) - { - cute::cp_async_wait(); - __syncthreads(); - - // Prefetch the first rmem from the first k-tile - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}), - tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{})); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}), - tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{})); - } - // k loop for mainloop - CUTLASS_PRAGMA_NO_UNROLL - for (; k_tile_count > 0; --k_tile_count) - { - cute::for_each(cute::make_int_sequence{}, - [&](auto k_block) - { - if (k_block == K_BLOCK_MAX - 1) - { - tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); - tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); - cute::cp_async_wait(); - __syncthreads(); - } - // Load A, B shmem->regs for k_block+1 - auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), - tOrInput_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), - tOrfc1_copy_view(cute::_, cute::_, k_block_next)); - // Copy gmem to smem before computing gemm on each k-pipe - if (k_block == 0) - { - // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - // tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); - cute::copy_if(gmem_tiled_copy_A, tInputpInput, - tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), - tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); - cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), - tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write)); - cute::cp_async_fence(); - if (k_tile_count - 1 > 0) - { - ++k_tile_iter; - } - - // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) - smem_pipe_write = smem_pipe_read; - ++smem_pipe_read; - smem_pipe_read = (smem_pipe_read == KT::Stages) ? 
0 : smem_pipe_read; - } - // Thread-level register gemm for k_block - cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), - accum); - }); - } - // load tail - cute::for_each(cute::make_int_sequence{}, - [&](auto WaitIndex) - { - k_tile_count--; - using WaitIndex_t = decltype(WaitIndex); - cute::for_each(cute::make_int_sequence{}, - [&](auto k_block) - { - if (k_block == K_BLOCK_MAX - 1) - { - tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); - tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); - cute::cp_async_wait(); - __syncthreads(); - } - // Load A, B shmem->regs for k_block+1 - auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), - tOrInput_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), - tOrfc1_copy_view(cute::_, cute::_, k_block_next)); - if (k_block == 0) - { - // only update smem_pipe_read - ++smem_pipe_read; - smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read; - } - // Thread-level register gemm for k_block - cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), - tOrfc1(cute::_, cute::_, k_block), accum); - }); - }); - // mma tail - cute::for_each(cute::make_int_sequence{}, - [&](auto k_block) - { - // Load A, B shmem->regs for k_block+1 - auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; - cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), - tOrInput_copy_view(cute::_, cute::_, k_block_next)); - cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), - tOrfc1_copy_view(cute::_, cute::_, k_block_next)); - // Thread-level register gemm for k_block - cute::gemm( - tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum); - }); - // if (cute::thread0()) { - // cute::print(accum_gate(0, 0, 0)); - // printf("\n"); - // } - // (2) add bias if it has.. - if (params.ptr_bias != nullptr) - { - cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); - cute::Tensor tOgBias = thr_mma.partition_C(gBias); - for (int i = 0; i < cute::size(accum); i++) - { - accum(i) += tOgBias(i); - } - } - // (3) calculate swiglu - using ActivationFn = typename KT::ActivationFn; - ActivationFn fn{}; - CUTLASS_PRAGMA_UNROLL - for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++) - { - accum(temp_iter) = fn(accum(temp_iter)); - } - - // (4) push all the result to smem - // (4.1) convert result from ElementAccum to ElementInput - cute::Tensor temp_accum = util_convert_type(accum); - // if (cute::thread0()) { - // cute::print(temp_accum(0, 0, 0)); - // printf("\n"); - // } - // (4.2) retile rf and smem for copy back.. - auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma); - auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); - // cute::clear(sO); - cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum); - cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO); - - // (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..) 
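> Both routine specializations above drive the same multi-stage cp.async pipeline: the prologue fills Stages-1 shared-memory buffers, and on each k-tile the stage that was just consumed becomes the next write slot. The sketch below isolates only that stage-index bookkeeping; the copies, fences, and MMAs are omitted.

```cpp
// Sketch of the shared-memory ring-buffer indices used by the mainloops above.
#include <cstdio>

int main() {
    constexpr int kStages = 3;
    int smem_pipe_read  = 0;
    // The prologue fills stages 0..kStages-2; the mainloop's first async copy
    // targets the remaining slot, so the write index starts at kStages-1.
    int smem_pipe_write = kStages - 1;

    for (int k_tile = 0; k_tile < 6; ++k_tile) {
        std::printf("k_tile %d: consume stage %d, prefetch into stage %d\n",
                    k_tile, smem_pipe_read, smem_pipe_write);
        // Advance exactly as in the kernel: reuse the consumed stage for the
        // next asynchronous copy, then move the read index forward (mod Stages).
        smem_pipe_write = smem_pipe_read;
        ++smem_pipe_read;
        smem_pipe_read = (smem_pipe_read == kStages) ? 0 : smem_pipe_read;
    }
}
```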
- cute::copy(smem_tiled_copy_O, taccumrO, taccumsO); - __syncthreads(); - - // (4.4) sO -> rO -> gO - - typename KT::GmemTiledCopyO gmem_tiled_copy_O; - auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); - // auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); // - cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx); - auto tOsO = gmem_thr_copy_O.partition_S(sO); - auto tOgO = gmem_thr_copy_O.partition_D(gO); - cute::Tensor cOutput = cute::make_identity_tensor( - cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{}))); - cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput); - cute::Tensor tOrO = cute::make_tensor(cute::shape(tOgO)); - cute::copy(gmem_tiled_copy_O, tOsO, tOrO); - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < cute::size<1>(tOgO); ++m) - { - if (cute::get<0>(tOcO(0, m, 0)) < residue_m) - { - cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_)); - } - } - } -}; - -} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh deleted file mode 100644 index b4c90085dbb..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh +++ /dev/null @@ -1,215 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#pragma once - -#include -#include -#include -#include - -namespace fused_moe -{ -template -struct Routine_Arguments -{ - ElementInput* ptr_input{}; - ElementWeight* ptr_fc1{}; - ElementInput* ptr_bias{}; - ElementOutput* ptr_output{}; - int64_t const* total_tokens_including_expert{}; - int gemm_n{}; - int gemm_k{}; - int num_expert{}; - bool bias_is_broadcast{}; -}; - -template -struct Routine_Params -{ - ElementInput* ptr_input{}; - ElementWeight* ptr_fc1{}; - ElementInput* ptr_bias{}; - ElementOutput* ptr_output{}; - int64_t const* total_tokens_including_expert{}; - int gemm_n{}; - int gemm_k{}; - int num_expert{}; - bool bias_is_broadcast{}; -}; - -enum class Activation_Type -{ - Gelu = 0, - Relu, - Silu, - Swiglu, - Geglu, - Identity, - InvalidType -}; - -constexpr bool isGateActivation(Activation_Type const& activation_type) -{ - return activation_type == Activation_Type::Swiglu || activation_type == Activation_Type::Geglu; -} - -template -constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) -{ - return Activation_Type::InvalidType; -} - -template <> -constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) -{ - return Activation_Type::Identity; -} - -template <> -constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) -{ - return Activation_Type::Relu; -} - -template <> -constexpr Activation_Type EpilogueRouting(bool is_gate) -{ - return is_gate ? Activation_Type::Swiglu : Activation_Type::Silu; -} - -template <> -constexpr Activation_Type EpilogueRouting(bool is_gate) -{ - return is_gate ? Activation_Type::Geglu : Activation_Type::Gelu; -} - -/* fusing all three kernels has many limitations. This is the simpler version. Just fuse first two kernels..*/ -template -struct Fused_Moe_Kernel_traits_sm80 -{ - using ElementInput = ElementInput_; - using ElementWeight = ElementWeight_; - using ElementAccum = float; - using ElementOutput = ElementOutput_; - - using index_t = uint32_t; - static_assert(TileM_ % 16 == 0); - static_assert(TileN_ % 32 == 0); - static_assert(TileK_ % 32 == 0); - static constexpr int Stages = Stages_; - static constexpr int kTileM = TileM_; - static constexpr int kTileN = TileN_; - static constexpr int kTileK = (kTileM > 16) ? (TileK_) : (TileK_ >= 64 ? TileK_ : 64); - - // tile shape - using TileShape = cute::Shape, cute::Int, cute::Int>; - static constexpr int kWarpsCount = 4; - static constexpr int kThreadCount = kWarpsCount * 32; - - // MMA atom arch and layout - using MMA_Atom_Arch = std::conditional_t, - cute::MMA_Atom, cute::MMA_Atom>; - // using ValLayoutMNK = cute::Layout>; - using ThreadLayoutMNK - = std::conditional_t, cute::_1>>, - cute::Layout, cute::_1>>>; - using ValLayoutMNK = std::conditional_t, - cute::Tile>; - using TiledMma = cute::TiledMMA; // 32x32x16 or 16x64x16 MMA for LDSM if kWarp = 4 - static constexpr int kAlignment = 8; - static constexpr int kBlcokKSmem = (kTileM == 16) ? 
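> The traits header above maps CUTLASS epilogue functor types to an `Activation_Type` and flags the gated variants (SwiGLU/GeGLU) used by the fused MoE FFN. A condensed sketch of that routing with simplified stand-in functor tags in place of the CUTLASS functors:

```cpp
// Sketch of activation routing: functor type + "is gate" flag -> activation enum.
#include <cstdio>

enum class Activation_Type { Gelu, Relu, Silu, Swiglu, Geglu, Identity, InvalidType };

constexpr bool isGateActivation(Activation_Type t) {
    return t == Activation_Type::Swiglu || t == Activation_Type::Geglu;
}

struct IdentityFn {}; struct ReLuFn {}; struct SiLuFn {}; struct GELUFn {};

template <class Fn> constexpr Activation_Type EpilogueRouting(bool) { return Activation_Type::InvalidType; }
template <> constexpr Activation_Type EpilogueRouting<IdentityFn>(bool) { return Activation_Type::Identity; }
template <> constexpr Activation_Type EpilogueRouting<ReLuFn>(bool) { return Activation_Type::Relu; }
template <> constexpr Activation_Type EpilogueRouting<SiLuFn>(bool is_gate) {
    return is_gate ? Activation_Type::Swiglu : Activation_Type::Silu;
}
template <> constexpr Activation_Type EpilogueRouting<GELUFn>(bool is_gate) {
    return is_gate ? Activation_Type::Geglu : Activation_Type::Gelu;
}

int main() {
    constexpr auto t = EpilogueRouting<SiLuFn>(true);
    static_assert(isGateActivation(t), "gated SiLU routes to SwiGLU");
    std::printf("gate activation selected: %d\n", static_cast<int>(t));
}
```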
64 : 32; - // A memory copy operand - using DefaultOperandA - = DefaultGemm_TensorOpSm80_OperandA; - using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; - using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; - using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; - - // B memory copy operand - using DefaultOperandB - = DefaultGemm_TensorOpSm80_OperandB; - using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; - using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; - using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; - - // Output memory copy operand - using SmemLayoutAtomO = SmemLayoutAtomA; - using SmemCopyAtomO = cute::Copy_Atom; - static constexpr int kGmemElementPerLoad = sizeof(cute::uint128_t) / sizeof(ElementOutput); - static constexpr int kGmemTrheadsPerRow = kBlcokKSmem / kGmemElementPerLoad; - using GmemLayoutAtomO - = cute::Layout, cute::Int>, - cute::Stride, cute::_1>>; - using GmemTiledCopyO = decltype(cute::make_tiled_copy(cute::Copy_Atom{}, - GmemLayoutAtomO{}, cute::Layout>{})); - - static_assert(cute::rank(SmemLayoutAtomA{}) == 2); - static_assert(cute::size<0>(TileShape{}) % cute::size<0>(SmemLayoutAtomA{}) == 0); // M - static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomA{}) == 0); // K - static_assert(cute::rank(SmemLayoutAtomB{}) == 2); - static_assert(cute::size<1>(TileShape{}) % cute::size<0>(SmemLayoutAtomB{}) == 0); // N - static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomB{}) == 0); // K - - using SmemLayoutA = decltype(cute::tile_to_shape(SmemLayoutAtomA{}, - cute::make_shape( - cute::shape<0>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int{}))); // BLK_M, BLK_K, Stages - using SmemLayoutB = decltype(cute::tile_to_shape(SmemLayoutAtomB{}, - cute::make_shape( - cute::shape<1>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int{}))); // BLK_N, BLK_K, Stages - using SmemLayoutO = decltype(cute::tile_to_shape( - SmemLayoutAtomO{}, cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<1>(TileShape{})))); // BLK_M, BLK_N - - // we need at least 2 stages.. - static_assert(Stages >= 2); - - struct SharedStorageNormal : cute::aligned_struct<128> - { - cute::array_aligned> smem_input; - cute::array_aligned> smem_fc1_weight; - cute::array_aligned> smem_o; - }; - - struct SharedStorageGate : cute::aligned_struct<128> - { - cute::array_aligned> smem_input; - cute::array_aligned> smem_fc1_gate_weight; - cute::array_aligned> smem_fc1_weight; - cute::array_aligned> smem_o; - }; - - using SharedStorage = std::conditional_t; - - using ActivationFn = std::conditional_t, - std::conditional_t, - std::conditional_t, cutlass::epilogue::thread::Identity>>>; - - static constexpr int kSmemSize = static_cast(sizeof(SharedStorage)); - - static constexpr bool can_implement(int const avaliable_smem_size) - { - return avaliable_smem_size > kSmemSize; - } - - // #endif -}; -} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h deleted file mode 100644 index 80a4d856085..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h +++ /dev/null @@ -1,73 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
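> `can_implement(avaliable_smem_size)` in the traits simply compares the computed `SharedStorage` size against what the device allows. Below is a host-side sketch of obtaining that limit via the opt-in shared-memory attribute; the `kSmemSize` value is an assumed example, not the real `sizeof(SharedStorage)`.

```cpp
// Host-side sketch of the shared-memory feasibility check behind can_implement().
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    int device = 0, smem_optin = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&smem_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

    constexpr int kSmemSize = 96 * 1024;   // pretend sizeof(SharedStorage)
    const bool ok = smem_optin > kSmemSize;
    std::printf("opt-in smem/block = %d bytes, kernel needs %d -> %s\n",
                smem_optin, kSmemSize, ok ? "can run" : "cannot run");
}
```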
- * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Scheduler for grouped GEMM -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" -#include "cutlass/matrix_coord.h" - -#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" -#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -/// Visitor class to abstract away the algorithm for iterating over tiles -template -struct GemmMoeProblemVisitor - : public MoeProblemVisitor, ThreadblockShape, - GroupScheduleMode_, PrefetchTileCount, ThreadCount> -{ - - static bool const kTransposed = Transposed; - - using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; - using Base - = MoeProblemVisitor; - using Params = typename Base::Params; - using SharedStorage = typename Base::SharedStorage; - - // - // Methods - // - CUTLASS_DEVICE - GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) - : Base(params_, shared_storage_, block_idx) - { - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp deleted file mode 100644 index 3a084ee04fb..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp +++ /dev/null @@ -1,70 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel -{ - -//////////////////////////////////////////////////////////////////////////////// - -/* - * Stateless universal device GEMM kernel type that treats GEMM as - * a composition of a collective mainloop and a collective epilogue. - * - * Supports both the 2.x and 3.x APIs based on whether the first type is - * a cute::tuple<> or not. - * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h - * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp - * - * In the following declaration, the name preceding the 'Or' refers to - * 3.x API type argument order, and the name succeeding the 'Or' refers to - * 2.x API type argument order. Template arguments without two names - * belong to the 3.x API only. - **/ -template -class GemmUniversalGated; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel - -//////////////////////////////////////////////////////////////////////////////// - -#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp" -#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp" -//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h deleted file mode 100644 index 0650ca8ded4..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h +++ /dev/null @@ -1,585 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. 
Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief GEMM kernel to support the epilogue visitor model - for customized softmax partial reduction epilogue fusion. - - This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once - its usage has been stabilized. For now, it is included in this example to demonstrate - some basic output fusion options. - - original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h -*/ - -#pragma once - -#include "cutlass/complex.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" -#include "cutlass/trace.h" - -#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h" - -namespace tk = tensorrt_llm::common; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct GemmWithEpilogueVisitor -{ -public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueVisitor = typename Epilogue::Visitor; - using ThreadblockSwizzle = ThreadblockSwizzle_; - - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using TensorRefA = TensorRef; - - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using TensorRefB = TensorRef; - - using ElementCompute = typename EpilogueVisitor::ElementCompute; - using LayoutAlphaCol = cutlass::layout::RowMajor; - using LayoutAlphaRow = cutlass::layout::ColumnMajor; - using TensorRefAlphaCol = TensorRef; - using TensorRefAlphaRow = TensorRef; - - using ElementC = typename EpilogueVisitor::ElementOutput; - using LayoutC = typename Epilogue::Layout; - using TensorRefC = TensorRef; - - static ComplexTransform const kTransformA = Mma::kTransformA; - static ComplexTransform const kTransformB = Mma::kTransformB; - using Operator = typename Mma::Operator; - - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename 
Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - using EpilogueOutputOp = - typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain - - static int const kStages = Mma::kStages; - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - /// Split-K preserves splits that are 128b aligned - static int const kSplitKAlignment - = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); - - // - // Structures - // - - /// Argument structure - struct Arguments - { - - // - // Data members - // - - GemmUniversalMode mode; - GemmCoord problem_size; - int batch_count; - - TensorRefA ref_A; - TensorRefB ref_B; - tk::QuantMode quant_option; - TensorRefAlphaCol ref_alpha_col; - TensorRefAlphaRow ref_alpha_row; - TensorRefC ref_C; - TensorRefC ref_D; - - int64_t batch_stride_A; - int64_t batch_stride_B; - int64_t batch_stride_D; - - typename EpilogueVisitor::Arguments epilogue_visitor; - - // - // Methods - // - - Arguments() - : mode(GemmUniversalMode::kGemm) - , batch_count(1) - { - } - - /// constructs an arguments structure - Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_, - TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_, - TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_, - int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_) - : mode(mode_) - , problem_size(problem_size_) - , batch_count(batch_count_) - , ref_A(ref_A_) - , ref_B(ref_B_) - , quant_option(quant_option_) - , ref_alpha_col(ref_alpha_col_) - , ref_alpha_row(ref_alpha_row_) - , ref_C(ref_C_) - , ref_D(ref_D_) - , batch_stride_A(batch_stride_A_) - , batch_stride_B(batch_stride_B_) - , batch_stride_D(0) - , epilogue_visitor(epilogue_visitor_) - { - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params - { - - cutlass::gemm::GemmCoord problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorB::Params params_B; - typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; - typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; - typename EpilogueVisitor::OutputTileIterator::Params params_C; - typename EpilogueVisitor::OutputTileIterator::Params params_D; - - GemmUniversalMode mode; - int batch_count; - int gemm_k_size; - - void* ptr_A; - void* ptr_B; - tk::QuantMode quant_option; - typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; - typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; - ElementC* ptr_C; - ElementC* ptr_D; - - int64_t batch_stride_A; - int64_t batch_stride_B; - - typename EpilogueVisitor::Params epilogue_visitor; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() - : swizzle_log_tile(0) - , params_A(0) - , params_B(0) - , params_alpha_col(0) - , params_C(0) - , params_D(0) - , batch_count(0) - , gemm_k_size(0) - , 
mode(cutlass::gemm::GemmUniversalMode::kGemm) - , ptr_A(nullptr) - , ptr_B(nullptr) - , ptr_alpha_col(nullptr) - , ptr_alpha_row(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , batch_stride_A(0) - , batch_stride_B(0) - { - } - - Params( - Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_) - : problem_size(args.problem_size) - , swizzle_log_tile(0) - , params_A(args.ref_A.layout()) - , params_B(args.ref_B.layout()) - , params_alpha_col(args.ref_alpha_col.layout()) - , params_alpha_row(args.ref_alpha_col.layout()) - , params_C(args.ref_C.layout()) - , params_D(args.ref_D.layout()) - , mode(args.mode) - , batch_count(args.batch_count) - , gemm_k_size(args.problem_size.k()) - , ptr_A(args.ref_A.data()) - , ptr_B(args.ref_B.data()) - , quant_option(args.quant_option) - , ptr_alpha_col(args.ref_alpha_col.data()) - , ptr_alpha_row(args.ref_alpha_row.data()) - , ptr_C(args.ref_C.data()) - , ptr_D(args.ref_D.data()) - , batch_stride_A(args.batch_stride_A) - , batch_stride_B(args.batch_stride_B) - , epilogue_visitor(args.epilogue_visitor) - { - - ThreadblockSwizzle threadblock_swizzle; - - grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size, - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); - - if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) - { - - int const kAlignK - = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); - - gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); - - if (gemm_k_size) - { - grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); - } - } - - swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); - } - }; - - /// Shared memory storage structure - union SharedStorage - { - - typename Mma::SharedStorage main_loop; - - struct - { - typename Epilogue::SharedStorage epilogue; - typename EpilogueVisitor::SharedStorage visitor; - } epilogue; - }; - -public: - // - // Methods - // - - CUTLASS_DEVICE - GemmWithEpilogueVisitor() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) - { - - CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); - - static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; - static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; - static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; - - bool isAMisaligned = false; - bool isBMisaligned = false; - bool isCMisaligned = false; - - if (platform::is_same::value) - { - isAMisaligned = problem_size.k() % kAlignmentA; - } - else if (platform::is_same::value) - { - isAMisaligned = problem_size.m() % kAlignmentA; - } - else if (platform::is_same>::value - || platform::is_same>::value) - { - isAMisaligned = problem_size.k() % kAlignmentA; - } - - if (platform::is_same::value) - { - isBMisaligned = problem_size.n() % kAlignmentB; - } - else if (platform::is_same::value) - { - isBMisaligned = problem_size.k() % kAlignmentB; - } - else if (platform::is_same>::value - || platform::is_same>::value) - { - isBMisaligned = problem_size.k() % kAlignmentB; - } - - if (platform::is_same::value) - { - isCMisaligned = problem_size.n() % kAlignmentC; - } - else if (platform::is_same::value) - { - isCMisaligned = problem_size.m() % kAlignmentC; - } - else if (platform::is_same>::value - || platform::is_same>::value) - { - 
isCMisaligned = problem_size.n() % kAlignmentC; - } - - if (isAMisaligned) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); - return Status::kErrorMisalignedOperand; - } - - if (isBMisaligned) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); - return Status::kErrorMisalignedOperand; - } - - if (isCMisaligned) - { - CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); - return Status::kErrorMisalignedOperand; - } - - CUTLASS_TRACE_HOST(" returning kSuccess"); - - return Status::kSuccess; - } - - static Status can_implement(Arguments const& args) - { - return can_implement(args.problem_size); - } - - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - - return 0; - } - -#define SPLIT_K_ENABLED 1 - - /// Executes one GEMM - CUTLASS_DEVICE - void run_kernel_(Params const& params, SharedStorage& shared_storage) - { - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - - cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // Early exit if CTA is out of range - if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() - || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) - { - - return; - } - - int offset_k = 0; - int problem_size_k = params.problem_size.k(); - - ElementA* ptr_A = static_cast(params.ptr_A); - ElementB* ptr_B = static_cast(params.ptr_B); - -#if SPLIT_K_ENABLED - // - // Fetch pointers based on mode. - // - if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) - { - - if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) - { - - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; - } - - offset_k = threadblock_tile_offset.k() * params.gemm_k_size; - } - else if (params.mode == GemmUniversalMode::kBatched) - { - ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; - ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; - } - else if (params.mode == GemmUniversalMode::kArray) - { - ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; - ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; - } -#endif - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_tile_offset.m() * Mma::Shape::kM, - offset_k, - }; - - cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); - - typename Mma::IteratorB iterator_B( - params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. 
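- // The shuffle below is the usual warp-uniform index idiom: every lane computes the
- // same value for threadIdx.x / 32, but broadcasting lane 0's result with __shfl_sync
- // lets the compiler treat warp_idx as provably uniform across the warp, so code that
- // branches on it is compiled as warp-uniform rather than potentially divergent.
- // A standalone sketch of the idiom (illustrative only):
- //
- //     __device__ __forceinline__ int canonical_warp_idx() {
- //         return __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
- //     }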
- int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - - int lane_idx = threadIdx.x % 32; - - // - // Main loop - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - - // - // Masked tile iterators constructed from members - // - - threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - // assume identity swizzle - MatrixCoord threadblock_offset( - threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); - - int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - - // - // Construct the epilogue visitor - // - - EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, - params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C, - params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C, - params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m()); - - if (params.mode == GemmUniversalMode::kGemm) - { - // Indicate which position in a serial reduction the output operator is currently updating - epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); - } - else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) - { - epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); - } - - // Construct the epilogue - Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); - - // Execute the epilogue operator to update the destination tensor. - epilogue(epilogue_visitor, accumulators); - } - - template - CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) - { - if constexpr (platform::is_same::value) - { - run_kernel_(params, shared_storage); - } - else - { - CUTLASS_NOT_IMPLEMENTED(); - } - } - - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. - */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { -#if defined(__CUDA_ARCH__) -#if (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 900) - // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels. - run_kernel(params, shared_storage); -#else - static_assert( - false, "Invalid architecture being compiled. 
Only Volta+ supported in weight-only quantization kernels."); -#endif -#else - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h deleted file mode 100644 index 6dc6ffc1a9f..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h +++ /dev/null @@ -1,143 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -/* - This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is - quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices - to be consumed by CUTLASS. - - Note that for int4, ThreadBlockK MUST be 64. - - */ - -#pragma once - -#include "cutlass/layout/matrix.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/arch/arch.h" -#include "cutlass/arch/mma.h" -#include "cutlass/platform/platform.h" - -#include "cutlass_extensions/arch/mma.h" -#include "cutlass_extensions/tile_interleaved_layout.h" - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -template -struct LayoutDetailsB -{ -}; - -// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. -// TODO - Switch this to column major for weights since gemms should be more performant. 
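- // Worked numbers for the interleaved quantized-B layouts defined in this header
- // (illustrative arithmetic only), taking TypeA = half_t (16-bit) and B = uint4b_t (4-bit):
- //   ThreadblockK         = 128 * 8 / sizeof_bits<TypeA>::value = 1024 / 16 = 64
- //   ElementsPerCacheLine = 128 * 8 / sizeof_bits<B>::value     = 1024 / 4  = 256
- //   ColumnsInterleaved   = ElementsPerCacheLine / ThreadblockK = 256 / 64  = 4
- // i.e. four 64-element column tiles of int4 weights are packed into each 128-byte
- // cache line, which matches the note above that ThreadBlockK must be 64 for int4.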
-template -struct LayoutDetailsB= 75>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -template -struct LayoutDetailsB= 75>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -template -struct LayoutDetailsB -{ - static constexpr int ThreadblockK = 64; - -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; - -public: - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; - // for fast accumulation - // using Operator = cutlass::arch::OpMultiplyAddFastAccum; -}; - -// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, -// which signals that we want to dequantize after loading from smem. -template - struct LayoutDetailsB < TypeA, - uint8_t, Arch, - typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; - -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; -}; - -template - struct LayoutDetailsB < TypeA, - uint4b_t, Arch, - typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - -private: - static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; - static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; - -public: - using Layout = layout::ColumnMajorTileInterleave; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; -}; - -template -struct LayoutDetailsB= 90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -template -struct LayoutDetailsB= 90>::type> -{ - static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; - using Layout = layout::ColumnMajor; - static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; - using Operator = cutlass::arch::OpMultiplyAdd; -}; - -} // namespace kernel -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh deleted file mode 100644 index aac2cb35799..00000000000 --- 
a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh +++ /dev/null @@ -1,185 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include -#include -#include -#include -#include - -template -struct DefaultGemm_TensorOpSm80_OperandA; - -template -struct DefaultGemm_TensorOpSm80_OperandB; - -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype(cute::composition( - cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); - using SmemCopyAtom = cute::Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype(cute::make_tiled_copy( - cute::Copy_Atom, cute::half_t>{}, - cute::Layout, cute::Stride>{}, - cute::Layout>{})); -}; - -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype(cute::composition( - cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); - using SmemCopyAtom = cute::Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype(cute::make_tiled_copy( - cute::Copy_Atom, cute::bfloat16_t>{}, - cute::Layout, cute::Stride>{}, - cute::Layout>{})); -}; - -/// Operand A - Column-major (M-major) -template -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype(cute::composition( - cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); - using SmemCopyAtom = cute::Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype(cute::make_tiled_copy( - cute::Copy_Atom, cute::half_t>{}, - cute::Layout, cute::Stride>{}, - cute::Layout>{})); -}; - -template -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype(cute::composition( - cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); - using SmemCopyAtom = cute::Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype(cute::make_tiled_copy( - cute::Copy_Atom, cute::bfloat16_t>{}, - cute::Layout, cute::Stride>{}, - cute::Layout>{})); -}; - -// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands - -// Operand B - Column-Major (K-major) -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{ -}; - -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{ -}; - -// Operand B - Row-Major (N-major) -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{ -}; - -template -struct DefaultGemm_TensorOpSm80_OperandB - : DefaultGemm_TensorOpSm80_OperandA -{ -}; - -// -// F16: 128-by-128-by-32 (small k-block) -// - -/// Operand A - Row-major (K-Major) -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype(cute::composition( - cute::Swizzle<2, 3, 3>{}, cute::Layout, cute::Stride>{})); - using SmemCopyAtom = cute::Copy_Atom; - - // Gmem - using GmemTiledCopy = 
decltype(cute::make_tiled_copy( - cute::Copy_Atom, cute::half_t>{}, - cute::Layout, cute::Stride>{}, - cute::Layout>{})); -}; - -template <> -struct DefaultGemm_TensorOpSm80_OperandA -{ - // Smem - using SmemLayoutAtom = decltype(cute::composition( - cute::Swizzle<2, 3, 3>{}, cute::Layout, cute::Stride>{})); - using SmemCopyAtom = cute::Copy_Atom; - - // Gmem - using GmemTiledCopy = decltype(cute::make_tiled_copy( - cute::Copy_Atom, cute::bfloat16_t>{}, - cute::Layout, cute::Stride>{}, - cute::Layout>{})); -}; - -template -CUTE_DEVICE auto util_convert_type(cute::Tensor const& tensor) -{ - using From_type = typename Engine::value_type; - constexpr int numel = decltype(cute::size(tensor))::value; - cutlass::NumericArrayConverter convert_op; - // HACK: this requires tensor to be "contiguous" - auto frag = convert_op(*reinterpret_cast const*>(tensor.data())); - return cute::make_tensor(cute::make_rmem_ptr(&frag), tensor.layout()); -} - -template -CUTE_DEVICE void util_copy( - TiledCopy const& tiled_copy, cute::Tensor const& S, cute::Tensor& D) -{ - CUTE_STATIC_ASSERT_V(cute::rank(S) == cute::Int<3>{}); - CUTE_STATIC_ASSERT_V(cute::rank(D) == cute::Int<3>{}); - CUTE_STATIC_ASSERT_V(cute::size<0>(S) == cute::size<0>(D)); - CUTE_STATIC_ASSERT_V(cute::size<1>(S) == cute::size<1>(D)); - CUTE_STATIC_ASSERT_V(cute::size<2>(S) == cute::size<2>(D)); - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < cute::size<1>(S); ++m) - { - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < cute::size<2>(S); ++k) - { - cute::copy(tiled_copy, S(cute::_, m, k), D(cute::_, m, k)); - } - } -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h deleted file mode 100644 index b708f7c28b5..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h +++ /dev/null @@ -1,553 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! 
\file - \brief -*/ - -#pragma once - -#include "cutlass/complex.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" - -#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/trace.h" - -#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" -#include "cutlass_extensions/tile_interleaved_layout.h" - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// -// This section exists to that we can use the same kernel code for regular gemm and dequantizing gemms. -// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global. -template -using void_t = void; - -template -struct use_dq_gemm : platform::false_type -{ -}; - -template -struct use_dq_gemm> : platform::true_type -{ -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MoeFCGemm -{ -public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; - static bool const kTransposed = false; - - // Optional transpose - using MapArguments = kernel::detail::MapArguments; - - // Public-facing type definitions related to operand element type, layout, and complex conjugate - // operation. Must interact with the 'kTransposed' notion. - static_assert(!kTransposed, "Transpose problem not supported"); - using ElementA = typename MapArguments::ElementA; - using LayoutA = typename MapArguments::LayoutA; - using ElementB = typename MapArguments::ElementB; - using LayoutB = typename MapArguments::LayoutB; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename MapArguments::LayoutC; - using ElementScale = ElementC; - - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static ComplexTransform const kTransformB = MapArguments::kTransformB; - - // Type definitions about the mainloop. 
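- // use_dq_gemm above is the classic detection idiom: its partial specialization is
- // viable only when the Mma type exposes a nested iterator for the dequantization
- // scales. A minimal, self-contained sketch of the same pattern in standard C++
- // (PlainMma / DequantMma and the member name checked here are illustrative):
- //
- //     #include <type_traits>
- //
- //     template <typename...> using void_t = void;
- //
- //     template <typename Mma, typename = void>
- //     struct has_scale_iterator : std::false_type {};
- //
- //     template <typename Mma>
- //     struct has_scale_iterator<Mma, void_t<typename Mma::IteratorScale>> : std::true_type {};
- //
- //     struct PlainMma {};
- //     struct DequantMma { struct IteratorScale {}; };
- //
- //     static_assert(!has_scale_iterator<PlainMma>::value, "takes the plain GEMM path");
- //     static_assert(has_scale_iterator<DequantMma>::value, "takes the dequantizing GEMM path");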
- using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = MapArguments::kAlignmentA; - static int const kAlignmentB = MapArguments::kAlignmentB; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using ProblemVisitor - = GemmMoeProblemVisitor; - - // - // Structures - // - - /// Argument structure - struct Arguments - { - - // - // Data members - // - - int problem_count; - int threadblock_count; - int group_size; - - typename EpilogueOutputOp::Params output_op; - - ElementA* ptr_A; - ElementB* ptr_B; - ElementScale* weight_scales; - ElementC* ptr_C; - ElementC* ptr_D; - bool C_is_broadcast; - - int64_t const* total_tokens_including_expert; - int64_t gemm_n; - int64_t gemm_k; - - // Only used by device-level operator - GemmCoord* host_problem_sizes; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() - : problem_count(0) - , threadblock_count(0) - , ptr_A(nullptr) - , ptr_B(nullptr) - , weight_scales(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , total_tokens_including_expert(nullptr) - , gemm_n(0) - , gemm_k(0) - , host_problem_sizes(nullptr) - , C_is_broadcast{true} - { - } - - /// Ctor - CUTLASS_HOST_DEVICE - Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op, - ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C, - bool C_is_broadcast, ElementC* ptr_D, int64_t const* total_tokens_including_expert, int64_t gemm_n, - int64_t gemm_k, GemmCoord* host_problem_sizes = nullptr) - : problem_count(problem_count) - , threadblock_count(threadblock_count) - , group_size(group_size) - , output_op(output_op) - , ptr_A(const_cast(ptr_A)) - , ptr_B(const_cast(ptr_B)) - , weight_scales(const_cast(weight_scales)) - , ptr_C(const_cast(ptr_C)) - , C_is_broadcast{C_is_broadcast} - , ptr_D(ptr_D) - , total_tokens_including_expert(total_tokens_including_expert) - , gemm_n(gemm_n) - , gemm_k(gemm_k) - , host_problem_sizes(nullptr) - { - if (platform::is_same::value || platform::is_same::value) - { - assert(weight_scales); - } - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params - { - - typename ProblemVisitor::Params problem_visitor; - int threadblock_count; - int group_size; - bool C_is_broadcast; - - typename EpilogueOutputOp::Params output_op; - - ElementA* ptr_A; - ElementB* ptr_B; - ElementScale* weight_scales; - ElementC* ptr_C; - ElementC* ptr_D; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() - : ptr_A(nullptr) - , ptr_B(nullptr) - , weight_scales(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , C_is_broadcast(true) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - : problem_visitor( - args.total_tokens_including_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count) - , threadblock_count(args.threadblock_count) - , group_size(args.group_size) - , 
output_op(args.output_op) - , ptr_A(args.ptr_A) - , ptr_B(args.ptr_B) - , weight_scales(args.weight_scales) - , ptr_C(args.ptr_C) - , ptr_D(args.ptr_D) - , C_is_broadcast(args.C_is_broadcast) - { - } - - CUTLASS_HOST_DEVICE - void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - { - - problem_visitor = typename ProblemVisitor::Params(args.total_tokens_including_expert, args.gemm_n, - args.gemm_k, args.problem_count, workspace, tile_count); - threadblock_count = args.threadblock_count; - output_op = args.output_op; - ptr_A = args.ptr_A; - ptr_B = args.ptr_B; - weight_scales = args.weight_scales; - ptr_C = args.ptr_C; - ptr_D = args.ptr_D; - C_is_broadcast = args.C_is_broadcast; - } - }; - - /// Shared memory storage structure - union SharedStorage - { - typename ProblemVisitor::SharedStorage problem_visitor; - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - }; - -public: - // - // Methods - // - - CUTLASS_DEVICE - MoeFCGemm() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) - { - return Status::kSuccess; - } - - static Status can_implement(Arguments const& args) - { - if (platform::is_same::value || platform::is_same::value) - { - if (args.weight_scales == nullptr) - { - CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t"); - return Status::kInvalid; - } - } - else if (args.weight_scales != nullptr) - { - CUTLASS_TRACE_HOST( - "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); - return Status::kInvalid; - } - else if (args.group_size != args.gemm_k) - { - CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)"); - return Status::kInvalid; - } - // Handle the case the input is too short - else if (args.gemm_n < Mma::IteratorB::AccessType::kElements) - { - CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment"); - return Status::kInvalid; - } - return Status::kSuccess; - } - - static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - - return 0; - } - - CUTLASS_DEVICE - void run_kernel_(Params const& params, SharedStorage& shared_storage) - { - // - // These types shadow the type-level definitions and support the ability to implement - // a 'transposed' GEMM that computes the transposed problems. - // - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; - static_assert(platform::is_same::value && kInterleave == 1 - || platform::is_same::value && kInterleave >= 1, - "B must be row major/col major OR col major interleaved."); - - // - // Problem visitor. 
- // - ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); - - const int64_t gemm_k = params.problem_visitor.gemm_k; - const int64_t gemm_n = params.problem_visitor.gemm_n; - int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; - - // Outer 'persistent' loop to iterate over tiles - int loop = 0; - while (problem_visitor.next_tile()) - { - loop++; - - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); - - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - - cutlass::gemm::GemmCoord threadblock_offset( - int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); - - // Load element pointers. Exchange pointers and strides if working on the transpose - const int64_t rows_to_jump - = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; - ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; - typename LayoutA::LongIndex ldm_A = gemm_k; - - char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix; - ElementB* ptr_B = reinterpret_cast(byte_ptr_B); - typename LayoutB::LongIndex ldm_B - = platform::is_same::value ? gemm_n : gemm_k * kInterleave; - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_offset.m(), - 0, - }; - - cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; - - cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A); - - typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, - {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - - int lane_idx = threadIdx.x % 32; - - // - // Matrix multiply phase - // - - // Construct thread-scoped matrix multiply - auto CreateMMA = [&]() - { - if constexpr (use_dq_gemm::value) - return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); - else - return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); - }; - Mma mma = CreateMMA(); - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Wait for all threads to finish their epilogue phases from the previous tile. 
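- // The barrier below is needed because SharedStorage is a union: the problem visitor,
- // the mainloop and the epilogue all alias one shared-memory allocation, so the
- // mainloop for this tile must not start writing shared memory until every thread has
- // finished its reads in the previous tile's epilogue.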
- __syncthreads(); - - // Compute threadblock-scoped matrix multiply-add - ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); - - if constexpr (use_dq_gemm::value) - { - const MatrixCoord scale_extent = {1, problem_size.n()}; - typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), - weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale); - - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); - } - else - { - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - } - - // - // Epilogue - // - - ElementC* ptr_C = reinterpret_cast(params.ptr_C) - + (params.C_is_broadcast ? problem_idx : rows_to_jump) * gemm_n; - ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; - - // lora need to set as layout_C(gemm_n) - LayoutC layout_C = params.C_is_broadcast ? LayoutC(0) : LayoutC(gemm_n); - LayoutC layout_D(gemm_n); - - typename Epilogue::OutputTileIterator::Params params_C(layout_C); - typename Epilogue::OutputTileIterator::Params params_D(layout_D); - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C( - params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn()); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn()); - - Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); - - // Execute the epilogue operator to update the destination tensor. - if constexpr (platform::is_same>::value) - { - EpilogueOutputOp output_op(params.output_op, problem_idx); - epilogue(output_op, iterator_D, accumulators, iterator_C); - } - else - { - EpilogueOutputOp output_op(params.output_op); - epilogue(output_op, iterator_D, accumulators, iterator_C); - } - - // Next tile - problem_visitor.advance(gridDim.x); - } - } - - template - CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) - { - if constexpr (platform::is_same::value) - { - run_kernel_(params, shared_storage); - } - else - { - CUTLASS_NOT_IMPLEMENTED(); - } - } - - /* - To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond - to the ArchTag of the cutlass kernel operator. - */ - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { -#if defined(__CUDA_ARCH__) -#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) - run_kernel(params, shared_storage); -#elif (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900) - constexpr bool isFp8 = platform::is_same::value - || platform::is_same::value; - if constexpr (isFp8) - { - run_kernel(params, shared_storage); - } - else - { // reuse sm80 kernel for other types, align with dispatchToArch - run_kernel(params, shared_storage); - } -#elif (__CUDA_ARCH__ >= 900) - run_kernel(params, shared_storage); -#else - static_assert( - false, "Invalid architecture being compiled. 
Only Volta+ supported in weight-only quantization kernels."); -#endif -#else - CUTLASS_NOT_IMPLEMENTED(); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h deleted file mode 100644 index 796dc2fe78d..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h +++ /dev/null @@ -1,344 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Base scheduler for grouped problems, using MoE -*/ - -#pragma once - -#include "cutlass/gemm/kernel/grouped_problem_visitor.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Visitor class to abstract away the algorithm for iterating over tiles -template -struct BaseMoeProblemVisitor -{ - using ThreadblockShape = ThreadblockShape_; - - struct ProblemInfo - { - static int32_t const kNoPrefetchEntry = -1; - int32_t problem_idx; - int32_t problem_start; - - CUTLASS_DEVICE - ProblemInfo() - : problem_idx(kNoPrefetchEntry) - , problem_start(kNoPrefetchEntry) - { - } - - CUTLASS_DEVICE - ProblemInfo(int32_t problem_idx_, int32_t problem_start_) - : problem_idx(problem_idx_) - , problem_start(problem_start_) - { - } - }; - - struct Params - { - int64_t const* last_row_for_problem; - int64_t gemm_n; - int64_t gemm_k; - int32_t problem_count; - void const* workspace; - int32_t tile_count; - - // - // Methods - // - - /// Ctor - CUTLASS_HOST_DEVICE - Params() - : last_row_for_problem(nullptr) - , gemm_n(0) - , gemm_k(0) - , problem_count(0) - , workspace(nullptr) - , tile_count(0) - { - } - - /// Ctor - CUTLASS_HOST_DEVICE - Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count, - void const* workspace = nullptr, int32_t tile_count = 0) - : last_row_for_problem(last_row_for_problem) - , gemm_n(gemm_n) - , gemm_k(gemm_k) - , problem_count(problem_count) - , workspace(workspace) - , tile_count(tile_count) - { - } - }; - - Params const& params; - int32_t tile_idx; - int32_t problem_tile_start; - int32_t problem_idx; - - // - // Methods - // - CUTLASS_DEVICE - BaseMoeProblemVisitor(Params const& params_, int32_t block_idx) - : params(params_) - , tile_idx(block_idx) - , 
problem_tile_start(0) - , problem_idx(0) - { - } - - /// Get the grid shape - CUTLASS_HOST_DEVICE - static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem) - { - - return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), - ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1); - } - - /// Gets the global tile index - CUTLASS_HOST_DEVICE - int32_t tile_index() const - { - return tile_idx; - } - - /// Gets the index of the problem - CUTLASS_HOST_DEVICE - int32_t problem_index() const - { - return problem_idx; - } - - CUTLASS_HOST_DEVICE - int32_t threadblock_idx() const - { - return tile_idx - problem_tile_start; - } - - CUTLASS_DEVICE - void advance(int32_t grid_size) - { - tile_idx += grid_size; - } - - CUTLASS_HOST_DEVICE - static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) - { - ProblemSizeHelper::possibly_transpose_problem(problem); - } - - /// Returns the problem size for the current problem - CUTLASS_HOST_DEVICE - cutlass::gemm::GemmCoord problem_size() const - { - return problem_size(problem_idx); - } - - CUTLASS_HOST_DEVICE - cutlass::gemm::GemmCoord problem_size(int idx) const - { - const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; - const int64_t current_problem_row = params.last_row_for_problem[idx]; - const int64_t gemm_m = current_problem_row - prev_problem_row; - GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k)); - ProblemSizeHelper::possibly_transpose_problem(problem); - return problem; - } - - CUTLASS_HOST_DEVICE - static int32_t tile_count(cutlass::gemm::GemmCoord const& grid) - { - return ProblemSizeHelper::tile_count(grid); - } - - static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count) - { - int32_t total_tiles = 0; - for (int32_t i = 0; i < problem_count; ++i) - { - auto problem = host_problem_sizes_ptr[i]; - possibly_transpose_problem(problem); - auto grid = grid_shape(problem); - total_tiles += tile_count(grid); - } - - return total_tiles; - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct MoeProblemVisitor; - -///////////////////////////////////////////////////////////////////////////////////////////////// -// ProblemVisitor that performs all scheduling on device -// -template -struct MoeProblemVisitor : public BaseMoeProblemVisitor -{ - using Base = BaseMoeProblemVisitor; - using Params = typename Base::Params; - static int const kThreadCount = ThreadCount; - static bool const kRequiresPrecomputation = false; - static int const kThreadsPerWarp = 32; - - struct SharedStorage - { - }; - - // Final tile of the problem loaded by this thread. Each thread will hold - // a separate value. - int32_t problem_ending_tile; - - SharedStorage& shared_storage; - - // - // Methods - // - CUTLASS_DEVICE - MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) - : Base(params_, block_idx) - , problem_ending_tile(0) - , shared_storage(shared_storage_) - { - this->problem_idx = -1 * kThreadsPerWarp; - this->problem_tile_start = 0; - } - - CUTLASS_DEVICE - bool next_tile() - { - // Check whether the tile to compute is within the range of the current problem. 
- int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); - if (this->tile_idx < problem_tile_end) - { - return true; - } - - // Check whether the tile to compute is within the current group of problems fetched by the warp. - // The last tile for this group is the final tile of the problem held by the final thread in the warp. - int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); - - // Keep the starting problem for this group in `problem_idx`. This is done to reduce - // register pressure. The starting problem for this group is simply the first problem - // in the group most recently fetched by the warp. - int32_t& group_problem_start = this->problem_idx; - group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp; - - // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce - // register pressure. - int32_t& group_tile_start = this->problem_tile_start; - - // Each thread in the warp processes a separate problem to advance until - // reaching a problem whose starting tile is less less than tile_idx. - while (group_tile_end <= this->tile_idx) - { - group_problem_start += kThreadsPerWarp; - if (group_problem_start > this->params.problem_count) - { - return false; - } - - // Since `group_tile_start` is a reference to `this->problem_tile_start`, this - // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start` - // is also set here is used later in `next_tile`. - group_tile_start = group_tile_end; - - int lane_idx = threadIdx.x % kThreadsPerWarp; - int32_t lane_problem = group_problem_start + lane_idx; - - // Compute the number of tiles in the problem assigned to each thread. - problem_ending_tile = 0; - if (lane_problem < this->params.problem_count) - { - cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem); - cutlass::gemm::GemmCoord grid = this->grid_shape(problem); - problem_ending_tile = this->tile_count(grid); - } - - // Compute a warp-wide inclusive prefix sum to compute the ending tile index of - // each thread's problem. - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < kThreadsPerWarp; i <<= 1) - { - int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i); - if (lane_idx >= i) - { - problem_ending_tile += val; - } - } - - // The total tile count for this group is now in the final position of the prefix sum - int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); - - problem_ending_tile += group_tile_start; - group_tile_end += tiles_in_group; - } - - // The next problem to process is the first one that does not have ending tile position - // that is greater than or equal to tile index. - int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx)); - - this->problem_idx = group_problem_start + problem_idx_in_group; - - // The starting tile for this problem is the ending tile of the previous problem. In cases - // where `problem_idx_in_group` is the first problem in the group, we do not need to reset - // `problem_tile_start`, because it is set to the previous group's ending tile in the while - // loop above. 
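The on-device scheduler deleted above relies on two warp-level idioms: a shuffle-based inclusive prefix sum that turns each lane's per-problem tile count into a cumulative ending-tile index, and a ballot/popcount search that finds which problem owns the current tile. A minimal standalone sketch of both, assuming a full 32-lane warp; the function names are illustrative, not the TensorRT-LLM code itself:

```cpp
#include <cstdint>

// Hillis-Steele inclusive scan across the 32 lanes of a warp.
__device__ int32_t warp_inclusive_prefix_sum(int32_t value, int lane_idx)
{
    for (int offset = 1; offset < 32; offset <<= 1)
    {
        int32_t up = __shfl_up_sync(0xffffffff, value, offset);
        if (lane_idx >= offset)
        {
            value += up;  // lane i accumulates lanes 0..i
        }
    }
    return value;
}

// Each lane holds the tile count of one problem in the current group of 32.
// Returns the index (within that group) of the problem that contains tile_idx.
__device__ int32_t owning_problem_in_group(int32_t tiles_in_my_problem, int32_t tile_idx, int lane_idx)
{
    int32_t ending_tile = warp_inclusive_prefix_sum(tiles_in_my_problem, lane_idx);
    // Lanes whose problems end at or before tile_idx are already exhausted;
    // counting them with popcount of the ballot gives the owner's position.
    return __popc(__ballot_sync(0xffffffff, ending_tile <= tile_idx));
}
```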
- if (problem_idx_in_group > 0) - { - this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1); - } - - return true; - } - - static size_t get_workspace_size( - cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count) - { - return 0; - } - - static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, - int32_t block_count, void* host_workspace_ptr) - { - } -}; - -} // namespace kernel -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp deleted file mode 100644 index e3d31a2c5b3..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp +++ /dev/null @@ -1,646 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ -#pragma once - -#include "cute/arch/cluster_sm90.hpp" -#include "cute/tensor.hpp" -#include "cutlass/arch/mma_sm90.h" -#include "cutlass/arch/reg_reconfig.h" -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/tile_scheduler.hpp" -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/trace.h" -#include "cutlass/workspace.h" - -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel -{ - -/////////////////////////////////////////////////////////////////////////////// - -template -class GemmUniversalGated - && CollectiveMainloop_::isGated>> -{ -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - using Activation = typename CollectiveMainloop::Activation; - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(ArchTag::kMinComputeCapability >= 90); - - using TileSchedulerTag = TileScheduler_; - using TileScheduler = - typename detail::TileSchedulerSelector::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; - using TileSchedulerParams = typename TileScheduler::Params; - - static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; - static constexpr uint32_t MaxThreadsPerBlock - = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - /// Register requirement for Load and Math WGs - static constexpr uint32_t LoadRegisterRequirement = 40; - static constexpr uint32_t MmaRegisterRequirement = 232; - - // 1 stage ordered sequence between mainloop and epilogue producer load threads - using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>; - - // Kernel level shared memory storage - struct SharedStorage - { - struct TensorStorage : cute::aligned_struct<128> - { - 
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - MainloopTensorStorage mainloop; - EpilogueTensorStorage epilogue; - } tensors; - - struct PipelineStorage : cute::aligned_struct<16> - { - using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; - } pipelines; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments - { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params - { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - void* workspace{nullptr}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static Params to_underlying_arguments(Arguments const& args, void* workspace) - { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - auto problem_shape = args.problem_shape; - // if constexpr (detail::IF_SWAP_AB::value) { - // // swap M/N - // get<0>(problem_shape) = get<1>(args.problem_shape); - // get<1>(problem_shape) = get<0>(args.problem_shape); - // } - auto problem_shape_MNKL = append<4>(problem_shape, 1); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) - { - CUTLASS_TRACE_HOST( - " WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; - - // Calculate workspace pointers - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - void* scheduler_workspace = workspace_ptr; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - void* epilogue_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - void* mainloop_workspace = nullptr; - // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used - // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means - // subtile will not be used, therefore separate reduction will not be enabled. 
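Both `to_underlying_arguments` and `initialize_workspace` in the kernel above carve the tile-scheduler and epilogue scratch space out of one device buffer, rounding every sub-allocation up to `MinWorkspaceAlignment` so the regions cannot overlap. A minimal sketch of that carving pattern; the names `round_up`/`WorkspaceCursor` and the 16-byte alignment are assumptions for illustration:

```cpp
#include <cstddef>
#include <cstdint>

constexpr std::size_t kMinWorkspaceAlignment = 16;  // illustrative value

constexpr std::size_t round_up(std::size_t bytes, std::size_t alignment)
{
    return ((bytes + alignment - 1) / alignment) * alignment;
}

// Bump-allocates aligned sub-buffers from a single pre-sized workspace.
struct WorkspaceCursor
{
    std::uint8_t* base;
    std::size_t offset = 0;

    void* take(std::size_t bytes)
    {
        void* ptr = base + offset;
        offset = round_up(offset + bytes, kMinWorkspaceAlignment);
        return ptr;
    }
};

// Usage mirrors the deleted code: get_workspace_size() sums the same rounded
// sizes, so the cursor never runs past the end of the buffer.
//   WorkspaceCursor cursor{workspace_ptr};
//   void* scheduler_ws = cursor.take(scheduler_bytes);
//   void* epilogue_ws  = cursor.take(epilogue_bytes);
```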
- constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, - ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); - - return {args.mode, problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, - scheduler, workspace}; - } - - static bool can_implement(Arguments const& args) - { - bool implementable = (args.mode == GemmUniversalMode::kGemm) - or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); - if (!implementable) - { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); - return implementable; - } - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - return implementable; - } - - static size_t get_workspace_size(Arguments const& args) - { - size_t workspace_size = 0; - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, - cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) - { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); - - status = TileScheduler::template initialize_workspace(args.scheduler, - workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, - NumEpilogueSubTiles); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) - { - return status; - } - - status = CollectiveEpilogue::initialize_workspace( - args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) - { - return status; - } - - return status; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 get_grid_shape(Params const& params) - { - // Given device SM count, set grid size s.t. 
we do not launch more thread blocks than we can run concurrently - TileSchedulerArguments args{}; - if constexpr (!std::is_const_v) - { - args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; - } - args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN - ? TileScheduler::RasterOrderOptions::AlongN - : TileScheduler::RasterOrderOptions::AlongM; - return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); - } - - static dim3 get_block_shape() - { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void operator()(Params const& params, char* smem_buf) - { - using namespace cute; - using X = Underscore; - -// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. -#if !defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); -#else - - // Preconditions - static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); - static_assert(size<0>(TileShape{}) >= 128, - "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); - - static_assert(cute::rank(StrideA{}) == 3, - "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, - "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, - "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, - "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ - enum class WarpGroupRole - { - Producer = 0, - Consumer0 = 1, - Consumer1 = 2 - }; - enum class ProducerWarpRole - { - Mainloop = 0, - Warp1 = 1, - Epilogue = 2, - Warp3 = 3 - }; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int thread_idx = int(threadIdx.x); - int lane_idx = canonical_lane_idx(); - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; - int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - int mma_thread_idx = thread_idx % size(TiledMma{}); - auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); - auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); - int lane_predicate = cute::elect_one_sync(); - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - - // Issue Tma Descriptor Prefetch from a single thread - if ((warp_idx == 0) && lane_predicate) - { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); - } - - // Mainloop Load pipeline - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) - { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) - { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.is_leader = 
warp_group_thread_idx == 0; - mainloop_pipeline_params.num_consumers = size(TiledMma{}); - mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) - { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) - { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; - epi_load_pipeline_params.consumer_arv_count = size(TiledMma{}); - epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - typename LoadWarpOrderBarrier::Params params_load_order_barrier; - params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; - params_load_order_barrier.group_size = NumThreadsPerWarp; - LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed empty - PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - auto cluster_wait_fn = []() - { - // We need this to guarantee that the Pipeline init is visible - // To all producers and consumer thread blocks in the Cluster - if constexpr (size(ClusterShape{}) > 1) - { - cute::cluster_arrive_relaxed(); - return []() { cute::cluster_wait(); }; - } - else - { - __syncthreads(); - return []() {}; // do nothing - } - }(); - - // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - - TileScheduler scheduler{params.scheduler}; - auto work_tile_info = scheduler.get_current_work(); - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue(params.epilogue, 
shared_storage.tensors.epilogue); - - // Prepare and partition the input tensors. Expects a tuple of tensors where: - // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) - // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) - auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); - static_assert(cute::tuple_size_v >= 3, - "Output of load_init must have at least three elements (A, B, Aux)"); - - // Extract out partitioned A and B. - Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - Tensor gAux_xkl = get<2>(load_inputs); - - // Get pipeline stage increments from tensor shapes - auto k_tile_count = size<3>(gA_mkl); - - // Wait for all thread blocks in the Cluster - cluster_wait_fn(); - - if (warp_group_role == WarpGroupRole::Producer) - { - cutlass::arch::warpgroup_reg_dealloc(); - - // Mainloop Producer Warp - if (producer_warp_role == ProducerWarpRole::Mainloop) - { - bool do_load_order_arrive = true; - while (work_tile_info.is_valid()) - { - if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) - { - work_tile_info = fetch_next_work(work_tile_info, scheduler); - continue; - } - - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Get the number of K tiles to compute for this work as well as the starting K tile offset of the - // work. - auto work_k_tile_count - = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); - auto k_tile_iter - = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); - - collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx, block_rank_in_cluster, - shared_storage.tensors.mainloop); - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(work_k_tile_count); - - // Signal for the epilogue load warp to begin - if (do_load_order_arrive) - { - load_order_barrier.arrive(); - do_load_order_arrive = false; - } - - // Get next work tile - work_tile_info = fetch_next_work(work_tile_info, scheduler); - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - } // Mainloop Producer Warp End - - // Epilogue Producer Warp - else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) - { - while (work_tile_info.is_valid()) - { - if (!TileScheduler::requires_separate_reduction(params.scheduler)) - { - load_order_barrier.wait(); - } - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) - { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - epi_load_pipe_producer_state 
= collective_epilogue.load(epi_load_pipeline, - epi_load_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx, - shared_storage.tensors.epilogue, work_tile_info.reduction_subtile_idx()); - } - - // Get next work tile - work_tile_info = fetch_next_work(work_tile_info, scheduler); - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } // Epilogue Producer Warp End - } // Producer Warp Group End - - else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) - { - cutlass::arch::warpgroup_reg_alloc(); - - // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it - bool do_store_tail = false; - float scale_d0 = params.mainloop.scale_d0; - float scale_d1 = params.mainloop.scale_d1; - while (work_tile_info.is_valid()) - { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - auto work_k_tile_count - = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); - - // Allocate the accumulators for the (M,N) blk_shape - // - // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. - auto accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) - auto accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) - if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) - { - collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, - accumulators1, work_k_tile_count, mma_thread_idx, shared_storage.tensors.mainloop, - params.mainloop); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, work_k_tile_count); - - // Update starting mainloop pipeline state for the next tile - mainloop_pipe_consumer_state.advance(work_k_tile_count); - } - // Index of warp group within consumer warp groups - int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; - - // Perform reduction across splits, if needed - TileScheduler::fixup( - params.scheduler, work_tile_info, accumulators0, NumMmaWarpGroups, consumer_warp_group_idx); - TileScheduler::fixup( - params.scheduler, work_tile_info, accumulators1, NumMmaWarpGroups, consumer_warp_group_idx); - - Activation elt_op; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accumulators0); i++) - { - accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]); - } - - if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) - { - // Epilogue and write to gD - auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] - = collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, - epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0, - tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue, - work_tile_info.reduction_subtile_idx()); - epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; - epi_store_pipe_producer_state = 
epi_store_pipe_producer_state_next; - do_store_tail = true; - } - - // Get next work tile - work_tile_info = fetch_next_work(work_tile_info, scheduler); - } // Scheduler work fetch loop - - if (do_store_tail) - { - collective_epilogue.store_tail( - epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, epi_store_pipe_producer_state); - } - } // Consumer Warp Groups End -#endif - } - -private: - // Kernel helper function to get next work unit - CUTLASS_DEVICE - typename TileScheduler::WorkTileInfo fetch_next_work( - typename TileScheduler::WorkTileInfo& work_tile_info, TileScheduler& scheduler) const - { - // Check whether we should continue on with the current work unit. If this is the case, - // the work unit will have been updated in continue_current_work to reflect the new - // tile to be computed. - if (scheduler.continue_current_work(work_tile_info)) - { - return work_tile_info; - } - - // Get next work tile - scheduler.advance_to_next_work(); - return scheduler.get_current_work(); - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp deleted file mode 100644 index 39886f2431d..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp +++ /dev/null @@ -1,621 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
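Both the cooperative kernel whose deletion ends here and the ping-pong kernel removed next are persistent: every thread block keeps asking the tile scheduler for another work tile until none remain, instead of being launched one block per tile. A toy host-side model of that loop shape; `ToyScheduler` and `WorkTileInfo` are stand-ins, not the CUTLASS tile-scheduler API:

```cpp
struct WorkTileInfo
{
    int linear_tile = -1;
    bool valid = false;
    bool is_valid() const { return valid; }
};

struct ToyScheduler
{
    int block_id = 0;     // which persistent block this is
    int grid_size = 1;    // number of persistent blocks launched
    int total_tiles = 0;
    int iteration = 0;

    WorkTileInfo get_current_work() const
    {
        int tile = block_id + iteration * grid_size;  // strided tile ownership
        return {tile, tile < total_tiles};
    }
    void advance_to_next_work() { ++iteration; }
};

template <class Body>
void persistent_loop(ToyScheduler scheduler, Body&& body)
{
    auto work = scheduler.get_current_work();
    while (work.is_valid())
    {
        body(work);                        // load / mma / epilogue for this tile
        scheduler.advance_to_next_work();  // stream-K may instead continue the same tile
        work = scheduler.get_current_work();
    }
}
```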
- * - **************************************************************************************************/ -#pragma once - -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/mma_sm90.h" -#include "cutlass/arch/reg_reconfig.h" -#include "cutlass/cutlass.h" -#include "cutlass/epilogue/collective/detail.hpp" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" -#include "cutlass/kernel_hardware_info.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cutlass/trace.h" -#include "cutlass/workspace.h" - -#include "cute/tensor.hpp" - -#include "cute/util/debug.hpp" - -/////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::kernel -{ - -/////////////////////////////////////////////////////////////////////////////// - -template -class GemmUniversalGated - && CollectiveMainloop_::isGated>> -{ -public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, - "ProblemShape{} should be or "); - // Mainloop derived types - using CollectiveMainloop = CollectiveMainloop_; - using TileShape = typename CollectiveMainloop::TileShape; - using TiledMma = typename CollectiveMainloop::TiledMma; - using ArchTag = typename CollectiveMainloop::ArchTag; - using ElementA = typename CollectiveMainloop::ElementA; - using StrideA = typename CollectiveMainloop::StrideA; - using ElementB = typename CollectiveMainloop::ElementB; - using StrideB = typename CollectiveMainloop::StrideB; - using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; - using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; - using ClusterShape = typename DispatchPolicy::ClusterShape; - using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - using Activation = typename CollectiveMainloop::Activation; - static_assert(ArchTag::kMinComputeCapability >= 90); - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using ElementC = typename CollectiveEpilogue::ElementC; - using StrideC = typename CollectiveEpilogue::StrideC; - using ElementD = typename CollectiveEpilogue::ElementD; - using StrideD = typename CollectiveEpilogue::StrideD; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - static_assert(!cute::is_same_v, - "Ping-pong kernel does not currently support stream-K scheduler."); - using TileSchedulerTag = TileScheduler_; - using TileScheduler = - typename detail::TileSchedulerSelector::Scheduler; - using TileSchedulerArguments = typename TileScheduler::Arguments; - using TileSchedulerParams = typename TileScheduler::Params; - - static constexpr uint32_t NumLoadWarpGroups = 1; - static constexpr uint32_t NumMmaWarpGroups = 2; - static constexpr uint32_t MaxThreadsPerBlock - = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); - static constexpr uint32_t MinBlocksPerMultiprocessor = 1; - - /// Register requirement for Load and Math WGs - static constexpr uint32_t LoadRegisterRequirement = 40; - static constexpr uint32_t MmaRegisterRequirement = 232; - - // 1 stage ordered sequence between mainloop and epilogue producer load threads - using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>; - - // Order Sequence barrier with two 
stages: one for Mainloop and one for Epilogue - static constexpr uint32_t StagesPerMathWarpGroup = 2; - using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier; - - // Kernel level shared memory storage - struct SharedStorage - { - struct TensorStorage : cute::aligned_struct<128> - { - using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; - using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; - - MainloopTensorStorage mainloop; - EpilogueTensorStorage epilogue; - } tensors; - - struct PipelineStorage : cute::aligned_struct<16> - { - using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; - using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; - using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; - - alignas(16) MainloopPipelineStorage mainloop; - alignas(16) EpiLoadPipelineStorage epi_load; - alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; - alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; - } pipelines; - }; - - static constexpr int SharedStorageSize = sizeof(SharedStorage); - - // Device side arguments - struct Arguments - { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerArguments scheduler{}; - }; - - // Kernel entry point API - struct Params - { - GemmUniversalMode mode{}; - ProblemShape problem_shape{}; - MainloopParams mainloop{}; - EpilogueParams epilogue{}; - KernelHardwareInfo hw_info{}; - TileSchedulerParams scheduler{}; - }; - - // - // Methods - // - - // Convert to underlying arguments. In this case, a simple copy for the aliased type. - static Params to_underlying_arguments(Arguments const& args, void* workspace) - { - CUTLASS_TRACE_HOST("to_underlying_arguments():"); - - (void) workspace; - auto problem_shape = args.problem_shape; - // if constexpr (detail::IF_SWAP_AB::value) { - // // swap M/N - // get<0>(problem_shape) = get<1>(args.problem_shape); - // get<1>(problem_shape) = get<0>(args.problem_shape); - // } - auto problem_shape_MNKL = append<4>(problem_shape, 1); - - // Get SM count if needed, otherwise use user supplied SM count - int sm_count = args.hw_info.sm_count; - if (sm_count <= 0) - { - CUTLASS_TRACE_HOST( - " WARNING: Arguments do not include a valid SM count.\n" - " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); - sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); - } - - CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); - KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; - - // Calculate workspace pointers - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - void* scheduler_workspace = workspace_ptr; - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - void* epilogue_workspace = workspace_ptr + workspace_offset; - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - - void* mainloop_workspace = nullptr; - - return {args.mode, problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, 
args.mainloop, mainloop_workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, - TileScheduler::to_underlying_arguments( - problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace)}; - } - - static bool can_implement(Arguments const& args) - { - bool implementable = (args.mode == GemmUniversalMode::kGemm) - or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); - if (!implementable) - { - CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); - return implementable; - } - implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); - implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); - implementable &= TileScheduler::can_implement(args.scheduler); - return implementable; - } - - static size_t get_workspace_size(Arguments const& args) - { - size_t workspace_size = 0; - workspace_size += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); - - return workspace_size; - } - - static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, - cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) - { - Status status = Status::kSuccess; - uint8_t* workspace_ptr = reinterpret_cast(workspace); - size_t workspace_offset = 0; - - status = TileScheduler::template initialize_workspace(args.scheduler, - workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset += TileScheduler::template get_workspace_size( - args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) - { - return status; - } - - status = CollectiveEpilogue::initialize_workspace( - args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); - workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); - workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); - if (status != Status::kSuccess) - { - return status; - } - - return status; - } - - // Computes the kernel launch grid shape based on runtime parameters - static dim3 get_grid_shape(Params const& params) - { - // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently - TileSchedulerArguments args{}; - if constexpr (!std::is_const_v) - { - args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; - } - args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN - ? TileScheduler::RasterOrderOptions::AlongN - : TileScheduler::RasterOrderOptions::AlongM; - return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); - } - - static dim3 get_block_shape() - { - return dim3(MaxThreadsPerBlock, 1, 1); - } - - CUTLASS_DEVICE - void operator()(Params const& params, char* smem_buf) - { - using namespace cute; - using X = Underscore; - -// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. 
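`get_grid_shape()` in both deleted kernels sizes a persistent launch: the grid never exceeds what the GPU can keep resident, and it stays a whole number of thread-block clusters. A hedged sketch of that sizing idea; the names and the exact capping policy are illustrative, not CUTLASS's scheduler:

```cpp
#include <algorithm>

struct GridShape { int x, y, z; };

GridShape persistent_grid(int problem_m, int problem_n, int tile_m, int tile_n,
                          int cluster_size, int sm_count)
{
    int tiles_m = (problem_m + tile_m - 1) / tile_m;
    int tiles_n = (problem_n + tile_n - 1) / tile_n;
    int total_tiles = tiles_m * tiles_n;

    // Keep the grid a whole number of clusters on both sides of the min().
    int tiles_rounded = ((total_tiles + cluster_size - 1) / cluster_size) * cluster_size;
    // MinBlocksPerMultiprocessor == 1 above, so budget one block per SM and
    // round down to an integer number of clusters.
    int max_blocks = std::max((sm_count / cluster_size) * cluster_size, cluster_size);
    return {std::min(tiles_rounded, max_blocks), 1, 1};
}
```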
-#if !defined(__CUDA_ARCH_FEAT_SM90_ALL) - printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); -#else - - // Preconditions - static_assert(cute::rank(StrideA{}) == 3, - "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideB{}) == 3, - "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideC{}) == 3, - "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(cute::rank(StrideD{}) == 3, - "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - enum class WarpGroupRole - { - Producer = 0, - Consumer0 = 1, - Consumer1 = 2 - }; - enum class ProducerWarpRole - { - Mainloop = 0, - Warp1 = 1, - Epilogue = 2, - Warp3 = 3 - }; - - // Kernel level shared memory storage - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - int thread_idx = int(threadIdx.x); - int lane_idx = canonical_lane_idx(); - int warp_idx = canonical_warp_idx_sync(); - int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; - int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; - auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); - auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); - int lane_predicate = cute::elect_one_sync(); - uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); - - // Issue Tma Descriptor Prefetch from a single thread - if ((warp_idx == 0) && lane_predicate) - { - CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); - CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); - } - - // Mainloop Load pipeline - using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; - typename MainloopPipeline::Params mainloop_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) - { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) - { - mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; - } - mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; - mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; - mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; - MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); - - // Epilogue Load pipeline - using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; - typename EpiLoadPipeline::Params epi_load_pipeline_params; - if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) - { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; - } - if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) - { - epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; - } - epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); - epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; - epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; - epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; - EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, 
epi_load_pipeline_params); - - // Epilogue Store pipeline - using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; - typename EpiStorePipeline::Params epi_store_pipeline_params; - epi_store_pipeline_params.always_wait = true; - EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); - - typename LoadWarpOrderBarrier::Params params_load_order_barrier; - params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; - params_load_order_barrier.group_size = NumThreadsPerWarp; - LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); - - typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; - // DMA Load WG will not participate in these Ordered Barrier syncs - params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); - params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group - MathWarpGroupOrderBarrier math_wg_order_barrier( - shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); - - // Initialize starting pipeline states for the collectives - // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) - typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; - typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; - - // For the DMA Load (producer) we start with an opposite phase - // i.e., we skip all waits since we know that the buffer is indeed empty - PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); - PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); - - auto cluster_wait_fn = [&]() - { - // We need this to guarantee that the Pipeline init is visible - // To all producers and consumer thread blocks in the Cluster - if constexpr (size(ClusterShape{}) > 1) - { - cute::cluster_arrive_relaxed(); - return []() { cute::cluster_wait(); }; - } - else - { - __syncthreads(); - return []() {}; // do nothing - } - }(); - - // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) - auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - TiledMma tiled_mma; - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - - // Prepare and partition the input tensors. Expects a tuple of tensors where: - // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) - // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) - auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); - static_assert(cute::tuple_size_v >= 3, - "Output of load_init must have at least three elements (A, B, Aux)"); - - // Extract out partitioned A and B. 
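The "opposite phase" comment above ("we skip all waits since we know that the buffer is indeed empty") refers to the index-plus-phase bookkeeping of the load pipelines: a state names one stage of the ring buffer, and its phase bit flips each time the index wraps. A simplified model of that convention; `ToyPipelineState` is our sketch, not cutlass::PipelineState:

```cpp
// Consumers start at phase 0 and block until a stage's "full" barrier reaches
// their expected phase; producers are constructed with the opposite phase so
// their first pass over the empty ring proceeds without waiting.
template <int Stages>
struct ToyPipelineState
{
    int index = 0;  // which of the Stages buffers this state refers to
    int phase = 0;  // flips every time index wraps around the ring

    void advance(int count = 1)
    {
        for (int i = 0; i < count; ++i)
        {
            if (++index == Stages)
            {
                index = 0;
                phase ^= 1;
            }
        }
    }
};

template <int Stages>
ToyPipelineState<Stages> make_toy_producer_start_state()
{
    return {0, 1};  // same slot as the consumers, opposite phase
}
```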
- Tensor gA_mkl = get<0>(load_inputs); - Tensor gB_nkl = get<1>(load_inputs); - Tensor gAux_xkl = get<2>(load_inputs); - - // Get pipeline stage increments from tensor shapes - auto k_tile_count = size<3>(gA_mkl); - auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); - auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); - - TileScheduler scheduler{params.scheduler}; - - if (warp_group_role == WarpGroupRole::Consumer1) - { - // Advance 2nd Math WG to the next work tile for the startup - scheduler.advance_to_next_work(); - // Advance 2nd Math WG pipeline states to the end of 1st Math WG - mainloop_pipe_consumer_state.advance(k_tile_count); - epi_load_pipe_consumer_state.advance(c_tile_count); - epi_store_pipe_producer_state.advance(d_tile_count); - } - auto work_tile_info = scheduler.get_current_work(); - - // Wait for all thread blocks in the Cluster - cluster_wait_fn(); - - if (warp_group_role == WarpGroupRole::Producer) - { - cutlass::arch::warpgroup_reg_dealloc(); - - // Mainloop Producer Warp - if (producer_warp_role == ProducerWarpRole::Mainloop) - { - bool do_load_order_arrive = true; - while (work_tile_info.is_valid()) - { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); - - collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, block_rank_in_cluster, - shared_storage.tensors.mainloop); - // Update starting pipeline state for the next tile - mainloop_pipe_producer_state.advance(k_tile_count); - - // Signal for the epilogue load warp to begin - if (do_load_order_arrive) - { - load_order_barrier.arrive(); - do_load_order_arrive = false; - } - - // Get next work tile - scheduler.advance_to_next_work(); - work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); - } // Mainloop Producer Warp End - - // Epilogue Producer Warp - else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) - { - load_order_barrier.wait(); - while (work_tile_info.is_valid()) - { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - epi_load_pipe_producer_state - = collective_epilogue.load(epi_load_pipeline, epi_load_pipe_producer_state, problem_shape_MNKL, - blk_shape, blk_coord, tiled_mma, lane_idx, shared_storage.tensors.epilogue); - - // Get next work tile - scheduler.advance_to_next_work(); - work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop - - // Make sure all Consumer Warp Groups have been waited upon - collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); - } // Epilogue Producer Warp End - } // Producer Warp Group End - - else if (warp_group_role == 
WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) - { - cutlass::arch::warpgroup_reg_alloc(); - - float scale_d0 = params.mainloop.scale_d0; - float scale_d1 = params.mainloop.scale_d1; - while (work_tile_info.is_valid()) - { - // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape - auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); - auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); - auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); - auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - - // Allocate the accumulators for the (M,N) blk_shape - Tensor accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) - Tensor accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) - - // Order two Math WG's MMA one after the other, helps hide Epilogue - math_wg_order_barrier.wait(); - - collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, accumulators1, - k_tile_count, warp_group_thread_idx, shared_storage.tensors.mainloop, params.mainloop); - - // Cue for next Math WG's MMA to start - math_wg_order_barrier.arrive(); - - // Make sure the math instructions are done and free buffers before entering the epilogue - collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count); - // Update starting mainloop pipeline state for the next tile - mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups); - - Activation elt_op; - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(accumulators0); i++) - { - accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]); - } - - // Order two Math WG's Epilogue one after the other - math_wg_order_barrier.wait(); - - // Epilogue and write to gD - auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] - = collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, - epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0, - tiled_mma, warp_group_thread_idx, shared_storage.tensors.epilogue); - - // TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels - // we need to wait for all TMA stores to complete before issuing consumer order barrier arrives - // to ensure next math consumer doesn't overwrite smem of in-flight TMA stores of current consumer. 
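In this ping-pong kernel the two math warp groups split the work tiles between them: Consumer1 advances the scheduler once at startup, and every consumer then strides by `NumMmaWarpGroups`, while the single producer walks all tiles. A tiny host-side illustration of the resulting interleaving (names are ours); in the real kernel each consumer also advances its pipeline states by one tile's worth of stages for every tile it skips:

```cpp
#include <cstdio>

int main()
{
    const int num_tiles = 8;
    const int num_mma_warp_groups = 2;  // NumMmaWarpGroups in the deleted kernel

    for (int wg = 0; wg < num_mma_warp_groups; ++wg)
    {
        // Warp group wg starts wg tiles in (Consumer1 calls advance_to_next_work()
        // once before the loop) and then strides by NumMmaWarpGroups, so the two
        // groups alternate tiles: WG0 -> 0,2,4,...  WG1 -> 1,3,5,...
        for (int tile = wg; tile < num_tiles; tile += num_mma_warp_groups)
        {
            std::printf("math warp group %d computes tile %d\n", wg, tile);
        }
    }
    return 0;
}
```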
- auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] - = collective_epilogue.store_tail(epi_load_pipeline, epi_load_pipe_consumer_state_next, - epi_store_pipeline, epi_store_pipe_producer_state_next); - - // Update starting load/store pipeline states for the next tile - // state has already been incremented by 1 tile in collective calls, advance once again for ping pong - epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_; - epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_; - epi_load_pipe_consumer_state.advance(c_tile_count); - epi_store_pipe_producer_state.advance(d_tile_count); - - // Cue for next Math WG's Epilogue to start - math_wg_order_barrier.arrive(); - - // Get next work tile - scheduler.advance_to_next_work(NumMmaWarpGroups); - work_tile_info = scheduler.get_current_work(); - } // Scheduler work fetch loop - } // Consumer Warp Groups End -#endif - } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::kernel diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h deleted file mode 100644 index 5e3531f0938..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h +++ /dev/null @@ -1,494 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/*! 
\file - \brief based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h -*/ - -#pragma once - -#include "cutlass/complex.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" - -#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" -#include "cutlass/gemm/kernel/gemm_transpose_operands.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/trace.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace kernel -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct SplitkGemmGrouped -{ -public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; - static bool const kTransposed = Transposed; - - // Optional transpose - using MapArguments = kernel::detail::MapArguments; - - // Public-facing type definitions related to operand element type, layout, and complex conjugate - // operation. Must interact with the 'kTransposed' notion. - using ElementA = typename MapArguments::ElementA; - using LayoutA = typename MapArguments::LayoutA; - using ElementB = typename MapArguments::ElementB; - using LayoutB = typename MapArguments::LayoutB; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename MapArguments::LayoutC; - - using ElementFinalOutput = typename MapArguments::ElementA; - - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static ComplexTransform const kTransformB = MapArguments::kTransformB; - - // Type definitions about the mainloop. 
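(Editorial note, not part of the diff.) The `kTransposed` / `MapArguments` machinery above rests on the identity (A·B)ᵀ = Bᵀ·Aᵀ: a column-major-output problem can be run through a row-major-output kernel by exchanging the A/B pointers and strides, which is exactly what the kernel body later does with `kTransposed ? params.ptr_B[...] : params.ptr_A[...]`. A tiny sketch that only checks the underlying identity, not the CUTLASS iterator machinery:

```cpp
#include <cassert>

// Verify the identity behind the kTransposed operand swap: (A*B)^T == B^T * A^T.
// Plain row-major reference loops on tiny matrices, for illustration only.
int main() {
    constexpr int M = 2, N = 3, K = 4;
    const float A[M][K] = {{1, 2, 3, 4}, {5, 6, 7, 8}};
    const float B[K][N] = {{1, 0, 2}, {0, 1, 3}, {4, 5, 6}, {7, 8, 9}};

    float C[M][N]  = {};  // C  = A * B       (M x N)
    float Ct[N][M] = {};  // Ct = B^T * A^T   (N x M)

    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
            for (int k = 0; k < K; ++k) {
                C[m][n]  += A[m][k] * B[k][n];
                Ct[n][m] += B[k][n] * A[m][k];
            }

    // The results agree up to a transpose, which is why the kernel can simply
    // exchange A/B pointers and strides when kTransposed is true.
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n) assert(C[m][n] == Ct[n][m]);
    return 0;
}
```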
- using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = MapArguments::kAlignmentA; - static int const kAlignmentB = MapArguments::kAlignmentB; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using ProblemVisitor - = GemmGroupedProblemVisitor; - - // - // Structures - // - - /// Argument structure - struct Arguments - { - - // - // Data members - // - - GemmCoord* problem_sizes; - int problem_count; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - ElementA** ptr_A; - ElementB** ptr_B; - ElementFinalOutput** ptr_C; - ElementFinalOutput** ptr_D; - - typename LayoutA::Stride::LongIndex* lda; - typename LayoutB::Stride::LongIndex* ldb; - typename LayoutC::Stride::LongIndex* ldc; - typename LayoutC::Stride::LongIndex* ldd; - - // Only used by device-level operator - GemmCoord* host_problem_sizes; - - // splitK - int split_k_slices; - int64_t* splitk_buffer_offsets; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() - : problem_count(0) - , threadblock_count(0) - , ptr_A(nullptr) - , ptr_B(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , lda(nullptr) - , ldb(nullptr) - , ldc(nullptr) - , ldd(nullptr) - , host_problem_sizes(nullptr) - , split_k_slices(1) - , splitk_buffer_offsets(nullptr) - { - } - - /// Ctor - CUTLASS_HOST_DEVICE - Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count, - typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C, - ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda, - typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc, - typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices, - int64_t* splitk_buffer_offsets) - : problem_sizes(problem_sizes) - , problem_count(problem_count) - , threadblock_count(threadblock_count) - , output_op(output_op) - , ptr_A(ptr_A) - , ptr_B(ptr_B) - , ptr_C(ptr_C) - , ptr_D(ptr_D) - , lda(lda) - , ldb(ldb) - , ldc(ldc) - , ldd(ldd) - , host_problem_sizes(host_problem_sizes) - , split_k_slices(split_k_slices) - , splitk_buffer_offsets(splitk_buffer_offsets) - { - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params - { - - typename ProblemVisitor::Params problem_visitor; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - ElementA** ptr_A; - ElementB** ptr_B; - ElementFinalOutput** ptr_C; - ElementFinalOutput** ptr_D; - ElementC* ptr_C_split; - ElementC* ptr_D_split; - - typename LayoutA::Stride::LongIndex* lda; - typename LayoutB::Stride::LongIndex* ldb; - typename LayoutC::Stride::LongIndex* ldc; - typename LayoutC::Stride::LongIndex* ldd; - - // - // Methods - // - - // splitk - GemmCoord grid_tiled_shape; - int swizzle_log_tile; - int gemm_k_size; - GemmCoord* host_problem_sizes; - int split_k_slices; - int64_t* splitk_buffer_offsets; - - CUTLASS_HOST_DEVICE - Params() - : ptr_A(nullptr) - 
, ptr_B(nullptr) - , ptr_C(nullptr) - , ptr_D(nullptr) - , ptr_C_split(nullptr) - , ptr_D_split(nullptr) - , lda(nullptr) - , ldb(nullptr) - , ldc(nullptr) - , ldd(nullptr) - , swizzle_log_tile(0) - , gemm_k_size(0) - , host_problem_sizes(nullptr) - , split_k_slices(1) - , splitk_buffer_offsets(nullptr) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - : problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count) - , host_problem_sizes(args.host_problem_sizes) - , threadblock_count(args.threadblock_count) - , output_op(args.output_op) - , ptr_A(args.ptr_A) - , ptr_B(args.ptr_B) - , ptr_C(args.ptr_C) - , ptr_D(args.ptr_D) - , ptr_C_split((ElementC*) workspace) - , ptr_D_split((ElementC*) workspace) - , lda(args.lda) - , ldb(args.ldb) - , ldc(args.ldc) - , ldd(args.ldd) - , split_k_slices(args.split_k_slices) - , splitk_buffer_offsets(args.splitk_buffer_offsets) - { - // Determine grid shape - ThreadblockSwizzle threadblock_swizzle; - grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0], - {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices); - swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); - - // only support same k - int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK; - int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); - - gemm_k_size = gemm_k_iterations * Mma::Shape::kK; - } - - CUTLASS_HOST_DEVICE - void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - { - - problem_visitor = - typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); - threadblock_count = args.threadblock_count; - output_op = args.output_op; - ptr_A = args.ptr_A; - ptr_B = args.ptr_B; - ptr_C = args.ptr_C; - ptr_D = args.ptr_D; - ptr_C_split = workspace; - ptr_D_split = workspace; - - lda = args.lda; - ldb = args.ldb; - ldc = args.ldc; - ldd = args.ldd; - } - }; - - /// Shared memory storage structure - struct SharedStorage - { - union - { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - } kernel; - - // ProblemVisitor shared storage can't be overlapped with others - typename ProblemVisitor::SharedStorage problem_visitor; - }; - -public: - // - // Methods - // - - CUTLASS_DEVICE - SplitkGemmGrouped() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) - { - return Status::kSuccess; - } - - static Status can_implement(Arguments const& args) - { - return Status::kSuccess; - } - - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { - - // - // These types shadow the type-level definitions and support the ability to implement - // a 'transposed' GEMM that computes the transposed problems. - // - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - - // - // Problem visitor. 
- // - ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); - - // Outer 'persistent' loop to iterate over tiles - while (problem_visitor.next_tile()) - { - - GemmCoord problem_size = problem_visitor.problem_size(); - int32_t problem_idx = problem_visitor.problem_index(); - int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); - - GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); - - // Load element pointers. Exchange pointers and strides if working on the transpose - ElementA* ptr_A - = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); - typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); - - ElementB* ptr_B - = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); - typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); - - // Compute threadblock location - ThreadblockSwizzle threadblock_swizzle; - GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); - - cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, - int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0); - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{ - threadblock_offset.m(), - threadblock_tile_offset.k() * params.gemm_k_size, - }; - - cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()}; - - // Problem size is a function of threadblock index in the K dimension - int problem_size_k; - if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) - { - problem_size_k = problem_size.k(); - } - else - { - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; - } - - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); - - typename Mma::IteratorB iterator_B( - LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = canonical_warp_idx_sync(); - - int lane_idx = threadIdx.x % 32; - - // - // Matrix multiply phase - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); - - // Wait for all threads to finish their epilogue phases from the previous tile. 
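(Editorial note, not part of the diff.) The split-K grouped kernel being removed here slices the K dimension evenly across `grid_tiled_shape.k()` threadblocks, with the last slice absorbing the remainder: `Params` precomputes `gemm_k_size`, and each threadblock derives its own `problem_size_k` and mainloop iteration count from its K-tile offset. A small host-side sketch of that partitioning arithmetic, with made-up sizes, is below.

```cpp
#include <cstdio>

// Sketch of the split-K partitioning used above: gemm_k_size from Params,
// problem_size_k per threadblock. K, tile_k and split_k_slices are made up.
int main() {
    const int K = 4096;            // problem K (the kernel assumes equal K across the group)
    const int tile_k = 64;         // Mma::Shape::kK
    const int split_k_slices = 3;  // grid_tiled_shape.k()

    const int full_k_iterations      = K / tile_k;
    const int k_iterations_per_slice = full_k_iterations / split_k_slices;
    const int gemm_k_size            = k_iterations_per_slice * tile_k;

    for (int slice = 0; slice < split_k_slices; ++slice) {
        const int k_begin = slice * gemm_k_size;
        // The last slice runs to the true K, exactly as problem_size_k does.
        const int k_end = (slice + 1 == split_k_slices) ? K : (slice + 1) * gemm_k_size;
        const int iterations = (k_end - k_begin + tile_k - 1) / tile_k;
        std::printf("slice %d: k = [%d, %d), %d mainloop iterations\n",
                    slice, k_begin, k_end, iterations);
    }
    return 0;
}
```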
- __syncthreads(); - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - ElementC* ptr_C = params.ptr_C_split; - ElementC* ptr_D = params.ptr_D_split; - - LayoutC layout_C(params.ldc[problem_idx]); - LayoutC layout_D(params.ldd[problem_idx]); - - typename Epilogue::OutputTileIterator::Params params_C(layout_C); - typename Epilogue::OutputTileIterator::Params params_D(layout_D); - - // assume identity swizzle - MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n()); - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_C( - params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C); - - iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() - + gridDim.z * params.splitk_buffer_offsets[problem_idx]); - - // Tile iterator writing to destination tensor. - typename Epilogue::OutputTileIterator iterator_D( - params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C); - iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() - + gridDim.z * params.splitk_buffer_offsets[problem_idx]); - - Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, iterator_D, accumulators, iterator_C); - - // Next tile - problem_visitor.advance(gridDim.x); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h deleted file mode 100644 index ed5e3e4daf8..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h +++ /dev/null @@ -1,125 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "cutlass_extensions/arch/mma.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ -//////////////////////////////////////////////////////////////////////////////// - -// We need to distinguish here, since we want volta support. It is too much effort -// to write shared memory iterators that are probably needed for volta to function -// properly. 
As a result, we allow converters both after the LDG (for volta) and after -// the LDS for Turing+. -template < - /// Iterator for B matrix in global memory - typename IteratorB, - /// Warp level Mma - typename MmaOperator, - /// Math operation perform by warp level operator - typename MathOperator> -struct SetConverters -{ -}; - -// Dequantize after LDG, so set transforms accordingly -template < - /// Iterator for B matrix in global memory - typename IteratorB, - /// Mma Policy - typename MmaOperator> -struct SetConverters -{ - using TransformAfterLDG - = FastInterleavedAndBiasedNumericArrayConverter; - - using TransformAfterLDS = NumericArrayConverter; -}; - -// Dequantize after LDS, so set transforms accordingly - -template < - /// Iterator for B matrix in global memory - typename IteratorB, - /// Mma Policy - typename MmaOperator> -struct SetConverters -{ - using TransformAfterLDG = NumericArrayConverter; - - using TransformAfterLDS - = FastInterleavedAndBiasedNumericArrayConverter; -}; - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - typename LayoutA_, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - typename LayoutB_, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale_, - /// Layout for the scale operand - typename LayoutScale_, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator_, - /// Layout type for C and D matrix operands - typename LayoutC_, - /// Operator class tag - typename OperatorClass_, - /// Tag indicating architecture to tune for - typename ArchTag_, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape_, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape_, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape_, - /// Number of stages used in the pipelined mainloop - int Stages, - /// Operation performed by GEMM - typename Operator_, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// - typename Enable = void> -struct DqMma; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h deleted file mode 100644 index 17c6346553c..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h +++ /dev/null @@ -1,302 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "cutlass/gemm/threadblock/default_mma.h" -#include "cutlass_extensions/arch/mma.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h" -#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" -#include "cutlass_extensions/tile_interleaved_layout.h" - -#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" -#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -//////////////////////////////////////////////////////////////////////////////// - -template -struct DefaultScaleIteratorsMultistage; - -// Fine grained iterators -template -struct DefaultScaleIteratorsMultistage> -{ - using IteratorScale - = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, - Layout, 0, Alignment>; - - using SmemIteratorScale = IteratorScale; -}; - -// Per column iterators -template -struct DefaultScaleIteratorsMultistage> -{ - // ThreadMap for scale iterator - static_assert((MmaShape::kN % Alignment) == 0, ""); - -private: - using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, - MmaShape::kN / Alignment, Alignment>; - -public: - // Define iterators over tiles from the scale operand - using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, - Element, Layout, 0, IteratorScaleThreadMap, Alignment>; - - using SmemIteratorScale = IteratorScale; -}; - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Type for element A - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Type for element B - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale, - /// Layout for the scale operand - typename LayoutScale, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Stages in GEMM - int kStages, - /// Operator performed by GEMM - typename Operator_, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear> -struct DqMma= 80 && !layout::IsColumnMajorTileInterleave::value)>::type> -{ - - static_assert(platform::is_same::value || platform::is_same::value - || platform::is_same::value, - "Element A must be fp16, fp8 or bf16"); - - using OperatorInfo = arch::DetagOperator; - using Operator = typename 
OperatorInfo::Operator; - static_assert(platform::is_same::value, - "Mma multistage must dequantize after ldsm"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, - AccessTypeA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, - AccessTypeB>; - - using ScaleIterators = DefaultScaleIteratorsMultistage; - - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; - - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - - using Converter = FastInterleavedAndBiasedNumericArrayConverter; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; -}; - -// Specialization to handle column major interleave B -template < - /// Type for element A - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Type for element B - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale, - /// Layout for the scale operand - typename LayoutScale, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Stages in GEMM - int kStages, - /// Operator performed by GEMM - typename Operator_, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear> -struct DqMma= 80 && layout::IsColumnMajorTileInterleave::value)>::type> -{ - - static_assert(platform::is_same::value || platform::is_same::value - || platform::is_same::value, - "Element A must be fp16, fp8 or bf16"); - - using OperatorInfo = arch::DetagOperator; - using Operator = typename OperatorInfo::Operator; - static_assert(platform::is_same::value, - "Mma multistage must dequantize 
after ldsm"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) - ? cutlass::arch::CacheOperation::Global - : cutlass::arch::CacheOperation::Always; - - // Define the MmaCore components - // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, - AccessTypeA>; - -private: - static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; - static constexpr int RowsPerTile = LayoutB::kRowsPerTile; - static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); - static_assert(RowsPerTile == MmaCore::Shape::kK, ""); - - using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; - using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; - static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); - - using GmemIteratorShape - = MatrixShape; - using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, OriginalThreadMap::kThreads, - layout::PitchLinearShape, - MmaCore::kAccessSizeInBits / sizeof_bits::value>; - -public: - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator; - - using ScaleIterators = DefaultScaleIteratorsMultistage; - - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; - - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - - using Converter = FastInterleavedAndBiasedNumericArrayConverter; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h deleted file mode 100644 index 345cd2eec9a..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h +++ /dev/null @@ -1,284 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "cutlass/gemm/threadblock/default_mma.h" -#include "cutlass_extensions/arch/mma.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h" -#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" -#include "cutlass_extensions/tile_interleaved_layout.h" - -#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" -#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -//////////////////////////////////////////////////////////////////////////////// - -template -struct DefaultScaleIteratorsPipelined; - -// Fine grained iterators -template -struct DefaultScaleIteratorsPipelined> -{ -private: - using SmemScaleType = half_t; - -public: - using IteratorScale - = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, - Layout, 0, Alignment>; - - using SmemIteratorScale - = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, - SmemScaleType, Layout, 0, Alignment>; -}; - -// Per column iterators -template -struct DefaultScaleIteratorsPipelined> -{ - static_assert((MmaShape::kN % Alignment) == 0, ""); - -private: - // ThreadMap for scale iterator - using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, - MmaShape::kN / Alignment, Alignment>; - using SmemScaleType = half_t; - -public: - // Define iterators over tiles from the scale operand - using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, - Element, Layout, 0, IteratorScaleThreadMap, Alignment>; - - using SmemIteratorScale - = cutlass::transform::threadblock::PredicatedTileIterator, SmemScaleType, - Layout, 0, IteratorScaleThreadMap, Alignment>; -}; - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Type for element A - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Type for element B - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale, - /// Layout for the scale operand - typename LayoutScale, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator_> -struct DqMma::value)>::type> -{ - - static_assert(platform::is_same::value || platform::is_same::value, - "Element A must be fp16 or bf16"); - - 
static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - using OperatorInfo = arch::DetagOperator; - using Operator = typename OperatorInfo::Operator; - static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); - - static constexpr bool DqAfterLDG = platform::is_same::value; - using MmaCoreElementA = half_t; - using MmaCoreElementB = typename platform::conditional::type; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, - typename MmaCore::IteratorThreadMapA, kAlignmentA>; - - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementB, LayoutB, 0, - typename MmaCore::IteratorThreadMapB, kAlignmentB>; - - using ScaleIterators = DefaultScaleIteratorsPipelined; - - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; - - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - - using Converters = SetConverters; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; -}; - -// Specialization to handle column major interleave B -template < - /// Type for element A - typename ElementA, - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Type for element B - typename ElementB, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for the input scale - typename ElementScale, - /// Layout for the scale operand - typename LayoutScale, - /// Access granularity of Scales in unit of elements - int kAlignmentScale, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Operator class tag - typename OperatorClass, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator_> -struct DqMma::value)>::type> -{ - - static_assert(platform::is_same::value || platform::is_same::value, - "Element A must be fp16 or bf16"); - - static_assert(platform::is_same::value || platform::is_same::value, - "Element B must be uint8 or uint4"); - - using OperatorInfo = arch::DetagOperator; - using Operator = typename OperatorInfo::Operator; - - static constexpr bool DqAfterLDG = platform::is_same::value; - using MmaCoreElementA = half_t; - using MmaCoreElementB = typename platform::conditional::type; - - // Define the MmaCore components - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, ElementA, LayoutA, 1, - typename MmaCore::IteratorThreadMapA, kAlignmentA>; - -private: - static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; - static constexpr int RowsPerTile = 
LayoutB::kRowsPerTile; - static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); - static_assert(RowsPerTile == MmaCore::Shape::kK, ""); - - using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; - using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; - static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); - - using GmemIteratorShape - = MatrixShape; - using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< - layout::PitchLinearShape, OriginalThreadMap::kThreads, - layout::PitchLinearShape, - MmaCore::kAccessSizeInBits / sizeof_bits::value>; - -public: - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator; - - // ThreadMap for scale iterator - static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); - using IteratorScaleThreadMap - = transform::PitchLinearStripminedThreadMap, - MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; - - using ScaleIterators = DefaultScaleIteratorsPipelined; - - // Define iterators over tiles from the scale operand - using IteratorScale = typename ScaleIterators::IteratorScale; - - using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; - - using Converters = SetConverters; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h deleted file mode 100644 index ad6c7496e14..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h +++ /dev/null @@ -1,351 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#pragma once - -#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" -#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" -#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h" - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2) -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2) -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage -/// (stage>=3) -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of 
elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int kStages, - /// Shared memory clear option - SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage -/// (stage>=3) -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int kStages, - /// Shared memory clear option - SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -#ifdef ENABLE_FP8 -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage -/// (stage>=3) -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile 
size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int kStages, - /// Shared memory clear option - SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -#endif - -// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on -// large tile when not enough shared mem is present to do 3+ stage -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB> -struct DefaultMma -{ - - // Define the MmaCore components - // 3 is used on purpose here to trigger components for mma multistage - using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, AccessTypeA, - GatherA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, AccessTypeB, - GatherB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h deleted file mode 100644 index 77af81005ab..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h +++ /dev/null @@ -1,353 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "cutlass/gemm/threadblock/default_mma.h" -#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" -#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB> -struct DefaultMma -{ - -private: - // Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS. - static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; - using MmaElementA = typename platform::conditional::type; - using MmaElementB = typename platform::conditional::type; - -public: - // Define the MmaCore components - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; - - using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, - typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; - - // Define iterators over tiles from the B operand - using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, - typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined; -}; - -// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on -// large tile when not enough shared mem is present to do 3+ stage -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear, - /// Gather operand A by using an index array - bool GatherA, - /// Gather operand B by using an index array - bool GatherB> -struct DefaultMma -{ - - // Define the MmaCore components - // 3 is used on purpose here to trigger components for mma multistage - using MmaCore = - typename cutlass::gemm::threadblock::DefaultMmaCore; - - // Define iterators over tiles from the A operand - using ThreadMapA = typename MmaCore::IteratorThreadMapA; - using AccessTypeA = cutlass::Array; - using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, - AccessTypeA, GatherA>; - - // Define iterators over tiles from the B operand - using ThreadMapB = typename MmaCore::IteratorThreadMapB; - using AccessTypeB = cutlass::Array; - using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< - cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, - AccessTypeB, GatherB>; - - // Define the threadblock-scoped multistage matrix multiply - using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & 
int4 weight -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int kStages, - /// Shared memory clear option - SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -//////////////////////////////////////////////////////////////////////////////// -/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight -template < - /// Layout type for A matrix operand - typename LayoutA, - /// Access granularity of A matrix in units of elements - int kAlignmentA, - /// Layout type for B matrix operand - typename LayoutB, - /// Access granularity of B matrix in units of elements - int kAlignmentB, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Tag indicating architecture to tune for - typename ArchTag, - /// Threadblock-level tile size (concept: GemmShape) - typename ThreadblockShape, - /// Warp-level tile size (concept: GemmShape) - typename WarpShape, - /// Instruction-level tile size (concept: GemmShape) - typename InstructionShape, - /// Operation performed by GEMM - typename Operator, - /// - int kStages, - /// Shared memory 
clear option - SharedMemoryClearOption SharedMemoryClear> -struct DefaultMma -{ - -private: - static constexpr int kAlignmentScale = 128 / sizeof_bits::value; - - using Mma = DqMma; - -public: - // Define the MmaCore components - using MmaCore = typename Mma::MmaCore; - - // Define iterators over tiles from the A operand - using IteratorA = typename Mma::IteratorA; - - // Define iterators over tiles from the B operand - using IteratorB = typename Mma::IteratorB; - - // Define the threadblock-scoped pipelined matrix multiply - using ThreadblockMma = typename Mma::ThreadblockMma; -}; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h deleted file mode 100644 index 1fb7f7eb28f..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h +++ /dev/null @@ -1,257 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
-*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/threadblock/mma_base.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass_extensions/weight_only_quant_op.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -//////////////////////////////////////////////////////////////////////////////// -// SFINAE trick so I can keep the same loop code for Volta and dispatch to the -// correct warp level mma. On volta, all data is stored to shared memory as FP16. -template -CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, - typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C, - int const warp_tileB_k_offset) -{ - warp_mma(D, A, B, C); -} - -template -CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, - typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B, - typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset) -{ - warp_mma(D, A, B, C, warp_tileB_k_offset); -} - -//////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// The type of the scales - typename ElementScale_, - /// Number of stages, - int Stages, - /// The dequantizing op to be performed. - WeightOnlyQuantOp DequantOp, - /// Used for partial specialization, - typename Enable = bool> -class DqMmaBase -{ -public: - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - - ///< Policy describing tuning details - using Policy = Policy_; - - ///< Type of the scale to be loaded - using ElementScale = ElementScale_; - - static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); - - // Finegrained scales get streamed in via cp.async - static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; - // We always have scales. - static constexpr int ScaleElementsPerStage = Shape::kN; - // We sometimes have a bias - static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0; - - // - // Dependent types - // - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Shape describing the overall GEMM computed from shared memory - /// by each warp. 
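The run_warp_mma overloads above use the SFINAE trick named in the comment so a single mainloop body can drive either warp-MMA interface: the older one that takes plain fragments and no k-tile offset, and the tensor-op one that takes transformed fragments plus an offset. Below is a minimal standalone C++17 sketch of the same dispatch idea; it keys off a marker member type instead of the fragment argument types the real overloads inspect, and every name in it is an illustrative stand-in rather than a CUTLASS type.

#include <cstdio>
#include <type_traits>

// Illustrative stand-ins for two warp-level MMA flavors.
struct PipelinedWarpMma {
    using FragmentC = float;
    void operator()(FragmentC& d, float a, float b, FragmentC c) const { d = a * b + c; }
};
struct MultistageWarpMma {
    using FragmentC = float;
    using TransformedFragmentA = float;  // presence of this type marks the offset-taking interface
    void operator()(FragmentC& d, float a, float b, FragmentC c, int /*k_off*/) const { d = a * b + c; }
};

// SFINAE detection: does WarpMma expose TransformedFragmentA?
template <typename WarpMma, typename = void>
struct uses_k_offset : std::false_type {};
template <typename WarpMma>
struct uses_k_offset<WarpMma, std::void_t<typename WarpMma::TransformedFragmentA>> : std::true_type {};

// One call site for both interfaces, mirroring the intent of run_warp_mma above.
template <typename WarpMma>
void run_warp_mma(WarpMma& mma, typename WarpMma::FragmentC& d, float a, float b,
                  typename WarpMma::FragmentC const& c, int k_off) {
    if constexpr (uses_k_offset<WarpMma>::value) {
        mma(d, a, b, c, k_off);
    } else {
        mma(d, a, b, c);
    }
}

int main() {
    PipelinedWarpMma p; MultistageWarpMma m;
    float d1 = 0, d2 = 0;
    run_warp_mma(p, d1, 2.f, 3.f, 1.f, 0);
    run_warp_mma(m, d2, 2.f, 3.f, 1.f, 0);
    std::printf("%.1f %.1f\n", d1, d2);
}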
- using WarpGemm = typename Policy::Operator::Shape; - - /// Shape describing the number of warps filling the CTA - using WarpCount = GemmShape; - - /// Number of warp-level GEMM operations - static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); - - static constexpr int kNumKIterationsPerWarpBLoad - = Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; - - static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); - static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; - - /// Number of stages - static int const kStages = Stages; - - /// Tensor reference to the A operand - using TensorRefA = TensorRef; - - /// Tensor reference to the B operand - using TensorRefB = TensorRef; - - // - // Nested structs - // - - /// Shared storage object needed by threadblock-scoped GEMM - class SharedStorage - { - public: - // - // Type definitions - // - - /// Shape of the A matrix operand in shared memory - using ShapeA - = MatrixShape; - - /// Shape of the B matrix operand in shared memory - using ShapeB - = MatrixShape; - - /// Shape of the shared memory buffer for the scales for the B matrix. - using ShapeScale = MatrixShape; - /// Shape of the shared memory buffer for the biases of the B matrix. - using ShapeZero = MatrixShape; - - public: - // - // Data members - // - - /// Buffer for A operand - AlignedBuffer operand_A; - - /// Buffer for B operand - AlignedBuffer operand_B; - - /// Buffer to hold scales for threadblock - AlignedBuffer operand_scale; - - /// Buffer to hold scales for threadblock - AlignedBuffer operand_zero; - - public: - // - // Methods - // - - /// Returns a layout object for the A matrix - CUTLASS_DEVICE - static typename Operator::LayoutA LayoutA() - { - return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); - } - - /// Returns a layout object for the B matrix - CUTLASS_HOST_DEVICE - static typename Operator::LayoutB LayoutB() - { - return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); - } - - /// Returns a TensorRef to the A operand - CUTLASS_HOST_DEVICE - TensorRefA operand_A_ref() - { - return TensorRefA{operand_A.data(), LayoutA()}; - } - - /// Returns a TensorRef to the B operand - CUTLASS_HOST_DEVICE - TensorRefB operand_B_ref() - { - return TensorRefB{operand_B.data(), LayoutB()}; - } - }; - -protected: - // - // Data members - // - - /// Iterator to load a warp-scoped tile of A operand from shared memory - typename Operator::IteratorA warp_tile_iterator_A_; - - /// Iterator to load a warp-scoped tile of B operand from shared memory - typename Operator::IteratorB warp_tile_iterator_B_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaBase( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - SharedStorage& shared_storage, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx) - , warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) - { - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h deleted file mode 100644 index 3c4036dd8cc..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h +++ /dev/null @@ -1,110 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. 
-template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Data type for the scales - typename IteratorScale_, - /// Iterators over scales in shared memory - typename SmemIteratorScale_, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Converter for B matrix applited immediately after the LDS - typename TransformBAfterLDS_, - /// The quantization operator being used - WeightOnlyQuantOp QuantOp_, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, - /// Used for partial specialization - typename Enable = void> -class DqMmaMultistage; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h" -#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h" diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h deleted file mode 100644 index f81961dee3c..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h +++ /dev/null @@ -1,708 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. 
- * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. 
-template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Iterators over scales in global memory - typename IteratorScale_, - /// Iterators over scales in shared memory - typename SmemIteratorScale_, - /// Data type of accumulator matrix - typename ElementC_, - /// Layout of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Converter for B matrix applied immediately after the LDS - typename TransformBAfterLDS_, - /// The quantization operator being used - WeightOnlyQuantOp QuantOp_, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear> -class DqMmaMultistage> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - using TransformBAfterLDS = TransformBAfterLDS_; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - // - // Dependent types - // - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); - static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); - - /// Internal structure exposed for introspection. 
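The static_asserts above tie the scale staging buffer to one shared-memory row per pipeline stage: fine-grained scales (and optional zeros) are streamed in with cp.async alongside each weight stage, whereas per-column scales only ever need a single row, as reflected by the ScalebiasStages constant in DqMmaBase. A standalone sketch of that sizing rule, assuming illustrative enum names in place of cutlass_extensions/weight_only_quant_op.h:

#include <cstdio>

// Illustrative stand-ins for the weight-only quantization modes.
enum class QuantMode { PerColumnScale, FinegrainedScale, FinegrainedScaleAndZeros };
constexpr bool is_finegrained(QuantMode m) { return m != QuantMode::PerColumnScale; }
constexpr bool has_zero(QuantMode m) { return m == QuantMode::FinegrainedScaleAndZeros; }

// Mirrors the DqMmaBase rule: fine-grained scales get one row per stage so they
// can be streamed with cp.async; per-column scales need a single row.
constexpr int scale_rows(QuantMode m, int stages) { return is_finegrained(m) ? stages : 1; }
constexpr int zero_cols(QuantMode m, int tile_n) { return has_zero(m) ? tile_n : 0; }

int main() {
    constexpr int kStages = 4, kTileN = 128;
    std::printf("fine-grained w/ zeros: %d x %d scales, %d x %d zeros\n",
                scale_rows(QuantMode::FinegrainedScaleAndZeros, kStages), kTileN,
                scale_rows(QuantMode::FinegrainedScaleAndZeros, kStages),
                zero_cols(QuantMode::FinegrainedScaleAndZeros, kTileN));
    std::printf("per-column:            %d x %d scales, %d x %d zeros\n",
                scale_rows(QuantMode::PerColumnScale, kStages), kTileN,
                scale_rows(QuantMode::PerColumnScale, kStages),
                zero_cols(QuantMode::PerColumnScale, kTileN));
}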
- struct Detail - { - - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA - = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB - = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -private: - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - /// The group size for quantization - int const group_size, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), - shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in 
units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } - - CUTLASS_DEVICE - void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1) - { - static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); - - typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale(); - typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero(); - - typename IteratorScale::AccessType* smem_scale_ptr - = reinterpret_cast(this->smem_iterator_scale_.get_scale()); - typename IteratorScale::AccessType* smem_zero_ptr - = reinterpret_cast(this->smem_iterator_scale_.get_zero()); - - int const kSrcBytes = sizeof_bits::value * IteratorScale::kAlignment / 8; - - cutlass::arch::cp_async(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); - - if (gmem_zero_ptr != nullptr) - { - cutlass::arch::cp_async(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); - } - - if (iterator_scale.group_size_ == 64) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if (iterator_scale.group_size_ == 128) - { - if constexpr (Shape::kK == 128) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if constexpr (Shape::kK == 64) - { - if (iterator_scale.row_groupsize64_ & 0x1) - { - iterator_scale.add_tile_offset({1, 0}); - } - } - else - { - static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); - } - } - - iterator_scale.row_groupsize64_++; - - this->smem_iterator_scale_.add_tile_offset({1, 0}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance( - IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) - { - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) - { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) - { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - 
cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over scale operand in global memory - IteratorScale iterator_scale, - ///< initial value of accumulator - FragmentC const& src_accum) - { - - // - // Prologue - // - - TransformBAfterLDS lds_converter; - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) - { - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_scale.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. 
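copy_scales_and_advance above only moves the global scale iterator once a full quantization group of K has been consumed, so with a 64-row K tile and a group size of 128 it advances on every second stage. A host-side model of that bookkeeping; the field names are illustrative and the group sizes mirror the two cases handled in the code:

#include <cstdio>

struct ScaleCursor {
    int group_size;          // quantization group size along K (64 or 128 here)
    int tile_k;              // threadblock K extent per mainloop stage
    int row_groupsize64 = 0; // how many 64-row K tiles have been consumed
    int scale_row = 0;       // scale row the next stage will load

    void advance_one_k_tile() {
        if (group_size == 64) {
            ++scale_row;
        } else if (group_size == 128) {
            if (tile_k == 128) ++scale_row;
            else if (tile_k == 64 && (row_groupsize64 & 0x1)) ++scale_row;
        }
        ++row_groupsize64;
    }
};

int main() {
    ScaleCursor c{/*group_size=*/128, /*tile_k=*/64};
    for (int stage = 0; stage < 6; ++stage) {
        std::printf("k tile %d -> scale row %d\n", stage, c.scale_row);
        c.advance_one_k_tile();
    }
}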
- // - - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) - { - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - - typename IteratorA::AccessType zero_A; - zero_A.clear(); - - last_smem_iterator_A.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { - - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_A.get()); - - *dst_ptr = zero_A; - - ++last_smem_iterator_A; - } - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; - - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { - - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_B.get()); - - *dst_ptr = zero_B; - - ++last_smem_iterator_B; - } - } - - // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - typename Dequantizer::FragmentScale warp_frag_scales; - typename Dequantizer::FragmentZero warp_frag_zeros; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - - warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - warp_dequantizer_.add_pointer_offset(Shape::kN); - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_scale.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) - { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. 
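The prologue above puts Base::kStages - 1 global-to-shared stages in flight before any math executes, the mainloop drains until gemm_k_iterations falls past -(kStages - 1), and the stage indices a little further down wrap around the circular shared-memory buffer. A host-side sketch of just that bookkeeping, with an assumed kStages of 3:

#include <cstdio>

// Pipeline bookkeeping only; all cp.async and warp-MMA work is elided.
int main() {
    constexpr int kStages = 3;           // assumed stage count for illustration
    int gemm_k_iterations = 8;           // K tiles to process

    gemm_k_iterations -= (kStages - 1);  // prologue issues kStages - 1 stages up front
    int smem_write_stage_idx = kStages - 1;
    int smem_read_stage_idx = 0;

    for (; gemm_k_iterations > -(kStages - 1); --gemm_k_iterations) {
        // ... compute on the read stage, issue copies into the write stage ...
        smem_write_stage_idx = (smem_write_stage_idx + 1) % kStages;
        smem_read_stage_idx = (smem_read_stage_idx + 1) % kStages;
        std::printf("k iterations left %2d  write stage %d  read stage %d\n",
                    gemm_k_iterations, smem_write_stage_idx, smem_read_stage_idx);
    }
}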
- - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros); - - using FragmentOperandB = cutlass::Array; - constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; - static_assert(ConversionVectorWidth == FragmentOperandB::kElements); - - using Converter - = cutlass::NumericArrayConverter; - - FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); - run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, - warp_tileB_k_compute_offset); - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) - { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // This is the first group of a given stage, so we issue the loads for the B scales immediately. - if (group_start_iteration_B == 0) - { - copy_scales_and_advance(iterator_scale); - } - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) - { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - - // #committed) - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } - else - { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); - smem_read_stage_idx = 0; - } - else - { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - iterator_scale.clear_mask(gemm_k_iterations == 0); - } - } - - // Load the scale needed for the next tile iteration. - warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); - // Update internal pointer to set of scales in shared memory. - warp_dequantizer_.add_pointer_offset(Shape::kN); - } - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h deleted file mode 100644 index 83efdc5cb01..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h +++ /dev/null @@ -1,647 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. 
Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/arch/memory.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math -/// instructions. 
-template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Cache operation for operand A - cutlass::arch::CacheOperation::Kind CacheOpA, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | - // MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Cache operation for operand B - cutlass::arch::CacheOperation::Kind CacheOpB, - /// Iterators over scales in global memory - typename IteratorScale_, - /// Iterators over scales in shared memory - typename SmemIteratorScale_, - /// Data type of accumulator matrix - typename ElementC_, - /// Layout of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Number of stages, - int Stages, - /// Converter for B matrix applited immediately after the LDS - typename TransformBAfterLDS_, - /// The quantization operator being used - WeightOnlyQuantOp QuantOp_, - /// Use zfill or predicate for out-of-bound cp.async - SharedMemoryClearOption SharedMemoryClear> -class DqMmaMultistage> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; - ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using Shape = Shape_; - ///< Iterates over tiles of A operand in global memory - using IteratorA = IteratorA_; - ///< Iterates over tiles of B operand in global memory - using IteratorB = IteratorB_; - ///< Data type of accumulator matrix - using ElementC = ElementC_; - ///< Layout of accumulator matrix - using LayoutC = LayoutC_; - ///< Policy describing tuning details - using Policy = Policy_; - - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; - static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; - - using TransformBAfterLDS = TransformBAfterLDS_; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - // - // Dependent types - // - - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Minimum architecture is Sm80 to support cp.async - using ArchTag = arch::Sm80; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - /// Internal structure exposed for introspection. 
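The Detail struct declared next spreads the cp.async copies for one stage across the warp-level GEMM iterations using a ceiling division, and the copy loops guard against over-issuing when the counts do not divide evenly. A standalone model with illustrative counts:

#include <cstdio>

int main() {
    constexpr int kWarpGemmIterations = 4;          // compute steps per stage (illustrative)
    constexpr int AsyncCopyIterationsPerStage = 9;  // cp.async issues per stage (illustrative)
    // Ceiling division, as in Detail::kAccessesPerGroupA / kAccessesPerGroupB.
    constexpr int kAccessesPerGroup =
        (AsyncCopyIterationsPerStage + kWarpGemmIterations - 1) / kWarpGemmIterations;  // == 3

    int issued = 0;
    for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations; ++warp_mma_k) {
        int group_start = warp_mma_k * kAccessesPerGroup;
        for (int j = 0; j < kAccessesPerGroup; ++j) {
            if (group_start + j < AsyncCopyIterationsPerStage) {  // same guard as copy_tiles_and_advance
                ++issued;
            }
        }
    }
    std::printf("issued %d of %d copies, %d per group\n",
                issued, AsyncCopyIterationsPerStage, kAccessesPerGroup);
}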
- struct Detail - { - - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - - /// Number of cp.async instructions to load one stage of operand A - static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; - - /// Number of cp.async instructions to load one stage of operand B - static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; - - /// Number of stages - static int const kStages = Stages; - - /// Number of cp.async instructions to load on group of operand A - static int const kAccessesPerGroupA - = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - - /// Number of cp.async instructions to load on group of operand B - static int const kAccessesPerGroupB - = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; - }; - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -private: - // - // Data members - // - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaMultistage( - ///< Shared storage needed for internal use by threadblock-scoped GEMM - typename Base::SharedStorage& shared_storage, - ///< Group size for quantization. 
Not used by this main loop since it assumes per-column - int const group_size, - ///< ID within the threadblock - int thread_idx, - ///< ID of warp - int warp_idx, - ///< ID of each thread within a warp - int lane_idx) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) - { - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } - - CUTLASS_DEVICE - void copy_tiles_and_advance( - IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) - { - iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); - this->smem_iterator_A_.set_iteration_index(group_start_A); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) - { - if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_A.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); - } - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - } - - iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); - this->smem_iterator_B_.set_iteration_index(group_start_B); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) - { - if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); - - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - auto gmem_ptr = iterator_B.get(); - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); - } - else - { - cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, 
iterator_B.valid()); - } - - ++iterator_B; - } - ++this->smem_iterator_B_; - } - } - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()( - ///< problem size of GEMM - int gemm_k_iterations, - ///< destination accumulator tile - FragmentC& accum, - ///< iterator over A operand in global memory - IteratorA iterator_A, - ///< iterator over B operand in global memory - IteratorB iterator_B, - ///< iterator over scale operand in global memory - IteratorScale iterator_scale, - ///< initial value of accumulator - FragmentC const& src_accum) - { - - // - // Prologue - // - - TransformBAfterLDS lds_converter; - - // NOTE - switch to ldg.sts - // Issue this first, so cp.async.commit_group will commit this load as well. - // Note: we do not commit here and this load will commit in the same group as - // the first load of A. - FragmentScale tb_frag_scales; - tb_frag_scales.clear(); - iterator_scale.load(tb_frag_scales); - this->smem_iterator_scale_.store(tb_frag_scales); - - // Issue several complete stages - CUTLASS_PRAGMA_UNROLL - for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) - { - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - iterator_A.set_iteration_index(0); - this->smem_iterator_A_.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_A_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; - - int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_A.get(), iterator_A.valid()); - - ++iterator_A; - } - - ++this->smem_iterator_A_; - } - - iterator_B.set_iteration_index(0); - this->smem_iterator_B_.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(this->smem_iterator_B_.get()); - - CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) - { - int const kSrcBytes = sizeof_bits::value - * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; - - cutlass::arch::cp_async_zfill( - dst_ptr + v, iterator_B.get(), iterator_B.valid()); - - ++iterator_B; - } - - ++this->smem_iterator_B_; - } - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Defines the boundary of a stage of cp.async. - cutlass::arch::cp_async_fence(); - } - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - // - // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels - // so that all accumulator elements outside the GEMM footprint are zero. 
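The copy routines above choose between two out-of-bounds behaviours for cp.async: with SharedMemoryClearOption::kZfill the shared-memory destination is zero-filled when the source predicate is false, while the plain predicated form skips the transfer and leaves the destination untouched. A toy host-side model of the difference:

#include <cstdio>
#include <cstring>

enum class OobPolicy { kZfill, kPredicated };

// Stand-in for one async copy: "valid" plays the role of the iterator predicate.
void copy_element(float* dst, const float* src, bool valid, OobPolicy policy) {
    if (valid) {
        std::memcpy(dst, src, sizeof(float));
    } else if (policy == OobPolicy::kZfill) {
        *dst = 0.0f;   // zfill: out-of-bounds slots become zero in shared memory
    }                  // predicated: no write at all
}

int main() {
    float smem[2] = {42.0f, 42.0f};  // pretend these are stale shared-memory values
    float gmem = 7.0f;
    copy_element(&smem[0], &gmem, /*valid=*/false, OobPolicy::kZfill);
    copy_element(&smem[1], &gmem, /*valid=*/false, OobPolicy::kPredicated);
    std::printf("zfill -> %.1f, predicated -> %.1f\n", smem[0], smem[1]);
}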
- // - - if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) - { - - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); - - typename IteratorA::AccessType zero_A; - zero_A.clear(); - - last_smem_iterator_A.set_iteration_index(0); - - // Async Copy for operand A - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) - { - - typename IteratorA::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_A.get()); - - *dst_ptr = zero_A; - - ++last_smem_iterator_A; - } - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); - typename IteratorB::AccessType zero_B; - - zero_B.clear(); - last_smem_iterator_B.set_iteration_index(0); - - // Async Copy for operand B - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) - { - - typename IteratorB::AccessType* dst_ptr - = reinterpret_cast(last_smem_iterator_B.get()); - - *dst_ptr = zero_B; - - ++last_smem_iterator_B; - } - } - - // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) - cutlass::arch::cp_async_wait(); - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math - // instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - typename Dequantizer::FragmentScale warp_frag_scales; - - Operator warp_mma; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - warp_dequantizer_.load(warp_frag_scales); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - - int smem_write_stage_idx = Base::kStages - 1; - int smem_read_stage_idx = 0; - - // - // Mainloop - // - - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > (-Base::kStages + 1);) - { - // - // Loop over GEMM K dimension - // - - // Computes a warp-level GEMM on data held in shared memory - // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if - // this is the last group as the case may be. 
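// Seen in isolation, the wrap-around mentioned in the comment above is plain modular
// arithmetic: the k-group to load next is (warp_mma_k + 1) modulo the number of
// warp-level GEMM iterations, so the last iteration wraps back to group 0. The
// kWarpGemmIterations value below is an assumed stand-in for Base::kWarpGemmIterations.
#include <cstdio>

int main()
{
    constexpr int kWarpGemmIterations = 4;  // assumed for illustration
    for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations; ++warp_mma_k)
    {
        std::printf("warp_mma_k=%d prefetches k-group %d into fragment slot %d\n",
                    warp_mma_k,
                    (warp_mma_k + 1) % kWarpGemmIterations,
                    (warp_mma_k + 1) % 2);
    }
    return 0;
}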
- - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); - - using FragmentOperandB = cutlass::Array; - constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; - constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements; - static_assert(ConversionVectorWidth == FragmentOperandB::kElements); - - using Converter - = cutlass::NumericArrayConverter; - - FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B); - run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum, - warp_tileB_k_compute_offset); - - // Issue global->shared copies for the this stage - if (warp_mma_k < Base::kWarpGemmIterations - 1) - { - int group_start_iteration_A, group_start_iteration_B; - - group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; - group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - } - - if (warp_mma_k + 2 == Base::kWarpGemmIterations) - { - int group_start_iteration_A, group_start_iteration_B; - group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA; - group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB; - - copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - - // Inserts a memory fence between stages of cp.async instructions. - cutlass::arch::cp_async_fence(); - - // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - - // #committed) - arch::cp_async_wait(); - __syncthreads(); - - // Move to the next stage - iterator_A.add_tile_offset({0, 1}); - iterator_B.add_tile_offset({1, 0}); - - this->smem_iterator_A_.add_tile_offset({0, 1}); - this->smem_iterator_B_.add_tile_offset({1, 0}); - - // Add negative offsets to return iterators to the 'start' of the - // circular buffer in shared memory - if (smem_write_stage_idx == (Base::kStages - 1)) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - smem_write_stage_idx = 0; - } - else - { - ++smem_write_stage_idx; - } - - if (smem_read_stage_idx == (Base::kStages - 1)) - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - smem_read_stage_idx = 0; - } - else - { - ++smem_read_stage_idx; - } - - --gemm_k_iterations; - iterator_A.clear_mask(gemm_k_iterations == 0); - iterator_B.clear_mask(gemm_k_iterations == 0); - } - } - } - - if (SharedMemoryClear == SharedMemoryClearOption::kZfill) - { - // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop - cutlass::arch::cp_async_fence(); - cutlass::arch::cp_async_wait<0>(); - __syncthreads(); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h deleted file mode 100644 index bd3e38971b0..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h +++ /dev/null @@ -1,106 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/gemm.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" -#include "cutlass_extensions/gemm_configs.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Data type for the scales - typename IteratorScale_, - /// Iterators over scales in shared memory - typename SmemIteratorScale_, - /// Data type of accumulator matrix - typename ElementC_, - /// Data type of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Converter for B matrix applied immediately after the LDG (before STS) - typename TransformBAfterLDG_, - /// Converter for B matrix applited immediately after the LDS - typename TransformBAfterLDS_, - /// The quantization operator being used - WeightOnlyQuantOp QuantOp_, - /// Used for partial specialization - typename Enable = void> -class DqMmaPipelined; - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h" -#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h" diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h deleted file mode 100644 index 50bdd0d85b0..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h +++ /dev/null @@ -1,486 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/gemm.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" -#include "cutlass_extensions/gemm_configs.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
-template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Iterators over scales in global memory - typename IteratorScale_, - /// Iterators over scales in shared memory - typename SmemIteratorScale_, - /// Data type of accumulator matrix - typename ElementC_, - /// Layout of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Converter for B matrix applied immediately after the LDG (before STS) - typename TransformBAfterLDG_, - /// Converter for B matrix applited immediately after the LDS - typename TransformBAfterLDS_, - /// The quantization operator being used - WeightOnlyQuantOp QuantOp_> -class DqMmaPipelined> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; - - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details - - using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - using TransformBAfterLDG = TransformBAfterLDG_; - using TransformBAfterLDS = TransformBAfterLDS_; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - // - // Dependent types - // - - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); - - static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, ""); - static_assert(Base::SharedStorage::ShapeScale::kColumn == 
Shape::kN, ""); - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using WarpFragmentScale = typename Dequantizer::FragmentScale; - using WarpFragmentZero = typename Dequantizer::FragmentZero; - - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -protected: - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaPipelined(typename Base::SharedStorage& - shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int const group_size, ///< The group size for quantization - int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), - shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) - { - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } - - CUTLASS_DEVICE - void copy_scales_and_advance(IteratorScale& iterator_scale) - { - using TransformScale = NumericArrayConverter; - - FragmentScale tb_frag_scales; - FragmentScale tb_frag_zeros; - tb_frag_scales.clear(); - tb_frag_zeros.clear(); - - TransformScale transformScale; - - using FragmentElement = typename FragmentScale::Element; - - auto gmem_scale_ptr = iterator_scale.get_scale(); - auto gmem_zero_ptr = iterator_scale.get_zero(); - - arch::global_load(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid()); - - if (gmem_zero_ptr != nullptr) - { - arch::global_load( - tb_frag_zeros, 
gmem_zero_ptr, iterator_scale.valid()); - } - - typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales); - typename TransformScale::result_type tb_frag_zeros_fp16; - if (gmem_zero_ptr != nullptr) - tb_frag_zeros_fp16 = transformScale(tb_frag_zeros); - - auto frag_scale_ptr_fp16 = reinterpret_cast(&tb_frag_scales_fp16); - auto frag_zero_ptr_fp16 = reinterpret_cast(&tb_frag_zeros_fp16); - auto smem_scale_ptr = this->smem_iterator_scale_.get_scale(); - auto smem_zero_ptr = this->smem_iterator_scale_.get_zero(); - - if (iterator_scale.valid()) - { - auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr); - arch::shared_store(smem_offset, frag_scale_ptr_fp16); - - if (gmem_zero_ptr != nullptr) - { - smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr); - arch::shared_store(smem_offset, frag_zero_ptr_fp16); - } - } - - if (iterator_scale.group_size_ == 64) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if (iterator_scale.group_size_ == 128) - { - if constexpr (Shape::kK == 128) - { - iterator_scale.add_tile_offset({1, 0}); - } - else if constexpr (Shape::kK == 64) - { - if (iterator_scale.row_groupsize64_ & 0x1) - { - iterator_scale.add_tile_offset({1, 0}); - } - } - else - { - static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); - } - } - - iterator_scale.row_groupsize64_++; - - this->smem_iterator_scale_.add_tile_offset({1, 0}); - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC& accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - IteratorScale iterator_scale, ///< iterator over scale operand in global memory - FragmentC const& src_accum) - { ///< source accumulator tile - - // - // Prologue - // - TransformBAfterLDG ldg_converter; - TransformBAfterLDS lds_converter; - - using TransformA - = NumericArrayConverter; - - // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want - // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
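// A minimal device-side sketch of the conversion mentioned above: bf16 data loaded
// from global memory is widened to float and narrowed to fp16 before the
// shared-memory store (STS), so HMMA can be issued on parts older than Ampere.
// The function name and raw-pointer interface are illustrative assumptions; the
// header itself goes through NumericArrayConverter fragments instead.
#include <cuda_bf16.h>
#include <cuda_fp16.h>

__device__ void store_fragment_as_fp16(__nv_bfloat16 const* src, __half* dst, int n)
{
    for (int i = 0; i < n; ++i)
    {
        // Round-trip through float, the usual bf16 -> fp16 conversion path.
        dst[i] = __float2half(__bfloat162float(src[i]));
    }
}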
- TransformA transformA; - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - FragmentA tb_frag_A; - FragmentB tb_frag_B; - - tb_frag_A.clear(); - tb_frag_B.clear(); - - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - this->smem_iterator_A_.store(transformA(tb_frag_A)); - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - copy_scales_and_advance(iterator_scale); - - __syncthreads(); - - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - WarpFragmentScale warp_frag_scales; - WarpFragmentZero warp_frag_zero; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - - warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - warp_dequantizer_.add_pointer_offset(Shape::kN); - - Operator warp_mma; - - int smem_write_stage_idx = 1; - - // Avoid reading out of bounds - iterator_A.clear_mask(gemm_k_iterations <= 1); - iterator_B.clear_mask(gemm_k_iterations <= 1); - iterator_scale.clear_mask(gemm_k_iterations <= 1); - - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) - { - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. - - if (warp_mma_k == Base::kWarpGemmIterations - 1) - { - - // Write fragments to shared memory - this->smem_iterator_A_.store(transformA(tb_frag_A)); - - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - - __syncthreads(); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); - } - else - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
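// The stage bookkeeping above (`smem_write_stage_idx ^= 1` together with the negative
// add_tile_offset calls) amounts to walking a two-slot circular buffer: the write
// pointer advances after every store and is pulled back by kStages whenever it runs
// off the end. The sketch below models only the write side and uses assumed names.
#include <cstdio>

int main()
{
    constexpr int kStages = 2;     // DqMmaPipelined is double-buffered
    int smem_write_stage_idx = 1;  // stage 0 was already filled in the prologue
    int write_slot = 1;            // slot targeted by the next shared-memory store

    for (int kblock = 0; kblock < 6; ++kblock)
    {
        std::printf("k-block %d stores into slot %d\n", kblock, write_slot);
        ++write_slot;  // the smem iterator advances after the store
        if (smem_write_stage_idx == 1)
        {
            write_slot -= kStages;  // rewind to the start of the circular buffer
        }
        smem_write_stage_idx ^= 1;
    }
    return 0;
}
// The printed slots alternate 1, 0, 1, 0, ..., which is the same ping-pong the real
// iterators perform via tile offsets rather than an explicit slot index.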
- if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - if (warp_mma_k == 0) - { - - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - copy_scales_and_advance(iterator_scale); - - // Avoid reading out of bounds if this was the last loop iteration - iterator_A.clear_mask(gemm_k_iterations <= 2); - iterator_B.clear_mask(gemm_k_iterations <= 2); - iterator_scale.clear_mask(gemm_k_iterations <= 2); - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero); - run_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); - } - - // Load the scales needed for the next tile iteration - warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); - // Update internal pointer to the set of scales in shared memory - warp_dequantizer_.add_pointer_offset(Shape::kN); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h deleted file mode 100644 index 316ea9f80a9..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h +++ /dev/null @@ -1,399 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Template for a double-buffered threadblock-scoped GEMM kernel. -*/ - -#pragma once - -#include "cutlass/aligned_buffer.h" -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/gemm/gemm.h" - -#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" -#include "cutlass_extensions/interleaved_numeric_conversion.h" - -#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" -#include "cutlass_extensions/gemm_configs.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace threadblock -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Iterates over tiles of A operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorA_, - /// Iterates over tiles of A operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorA_, - /// Iterates over tiles of B operand in global memory - // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) - typename IteratorB_, - /// Iterates over tiles of B operand in shared memory - /// (concept: WriteableTileIterator | RandomAccessTileIterator) - typename SmemIteratorB_, - /// Iterators over scales in global memory - typename IteratorScale_, - /// Iterators over scales in shared memory - typename SmemIteratorScale_, - /// Data type of accumulator matrix - typename ElementC_, - /// Layout of accumulator matrix - typename LayoutC_, - /// Policy describing tuning details (concept: MmaPolicy) - typename Policy_, - /// Converter for B matrix applied immediately after the LDG (before STS) - typename TransformBAfterLDG_, - /// Converter for B matrix applited immediately after the LDS - typename TransformBAfterLDS_, - /// The quantization operator being used - WeightOnlyQuantOp QuantOp_> -class DqMmaPipelined> - : public DqMmaBase -{ -public: - ///< Base class - using Base = DqMmaBase; - - using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> - using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory - using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory - using ElementC = ElementC_; ///< Data type of accumulator matrix - using LayoutC = LayoutC_; ///< Layout of accumulator matrix - using Policy = Policy_; ///< Policy describing tuning details - - 
using IteratorScale = IteratorScale_; - using ElementScale = typename IteratorScale::Element; - using LayoutScale = typename IteratorScale::Layout; - - using SmemIteratorA = SmemIteratorA_; - using SmemIteratorB = SmemIteratorB_; - using SmemIteratorScale = SmemIteratorScale_; - - using TransformBAfterLDG = TransformBAfterLDG_; - using TransformBAfterLDS = TransformBAfterLDS_; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - // - // Dependent types - // - - /// Fragment of operand A loaded from global memory - using FragmentA = typename IteratorA::Fragment; - - /// Fragment of operand B loaded from global memory - using FragmentB = typename IteratorB::Fragment; - - /// Fragment of operand Scale loaded from global memory; - using FragmentScale = typename IteratorScale::Fragment; - - /// Fragment of accumulator tile - using FragmentC = typename Policy::Operator::FragmentC; - - /// Warp-level Mma - using Operator = typename Policy::Operator; - - /// Obtain the arch tag from the warp-level operator - using ArchTag = typename Policy::Operator::ArchTag; - - using Dequantizer = warp::MmaTensorOpDequantizer; - - /// Complex transform on A operand - static ComplexTransform const kTransformA = Operator::kTransformA; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = Operator::kTransformB; - - // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) - static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); - -private: - using WarpFragmentA = typename Operator::FragmentA; - using WarpFragmentB = typename Operator::FragmentB; - Dequantizer warp_dequantizer_; - - using ElementA = typename IteratorA::Element; - using ElementB = typename IteratorB::Element; - using LayoutDetailsForB = kernel::LayoutDetailsB; - - static constexpr bool RequiresTileInterleave - = layout::IsColumnMajorTileInterleave::value; - static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), - "Layout K must match threadblockK"); - -protected: - /// Iterator to write threadblock-scoped tile of A operand to shared memory - SmemIteratorA smem_iterator_A_; - - /// Iterator to write threadblock-scoped tile of B operand to shared memory - SmemIteratorB smem_iterator_B_; - - /// Iterator to write threadblock-scoped tile of scale operand to shared memory - SmemIteratorScale smem_iterator_scale_; - -public: - /// Construct from tensor references - CUTLASS_DEVICE - DqMmaPipelined(typename Base::SharedStorage& - shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM - int const group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation - ///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this - ///< argument is not added, it does not affect compilation for sm>=80. 
- int thread_idx, ///< ID within the threadblock - int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp - ) - : Base(shared_storage, thread_idx, warp_idx, lane_idx) - , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, - (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) - , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) - , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) - , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) - { - - // Compute warp location within threadblock tile by mapping the warp_id to - // three coordinates: - // _m: the warp's position within the threadblock along the M dimension - // _n: the warp's position within the threadblock along the N dimension - // _k: the warp's position within the threadblock along the K dimension - - int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); - int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); - - int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; - int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; - - // Add per-warp offsets in units of warp-level tiles - this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); - this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); - } - - /// Perform a threadblock-scoped matrix multiply-accumulate - CUTLASS_DEVICE - void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop - FragmentC& accum, ///< destination accumulator tile - IteratorA iterator_A, ///< iterator over A operand in global memory - IteratorB iterator_B, ///< iterator over B operand in global memory - IteratorScale iterator_scale, ///< iterator over scale operand in global memory - FragmentC const& src_accum) - { ///< source accumulator tile - - // - // Prologue - // - TransformBAfterLDG ldg_converter; - TransformBAfterLDS lds_converter; - - using TransformA - = NumericArrayConverter; - - using TransformScale = NumericArrayConverter; - - // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want - // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
- TransformA transformA; - TransformScale transformScale; - - // Perform accumulation in the 'd' output operand - accum = src_accum; - - FragmentA tb_frag_A; - FragmentB tb_frag_B; - FragmentScale tb_frag_scales; - - using WarpFragmentScale = typename Dequantizer::FragmentScale; - WarpFragmentScale warp_frag_scales; - - tb_frag_A.clear(); - tb_frag_B.clear(); - tb_frag_scales.clear(); - - // The last kblock is loaded in the prolog - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - iterator_scale.load(tb_frag_scales); - - ++iterator_A; - ++iterator_B; - - this->smem_iterator_A_.store(transformA(tb_frag_A)); - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - __syncthreads(); - - warp_dequantizer_.load(warp_frag_scales); - - // Pair of fragments used to overlap shared memory loads and math instructions - WarpFragmentA warp_frag_A[2]; - WarpFragmentB warp_frag_B[2]; - - this->warp_tile_iterator_A_.set_kgroup_index(0); - this->warp_tile_iterator_B_.set_kgroup_index(0); - - this->warp_tile_iterator_A_.load(warp_frag_A[0]); - this->warp_tile_iterator_B_.load(warp_frag_B[0]); - - ++this->warp_tile_iterator_A_; - ++this->warp_tile_iterator_B_; - - Operator warp_mma; - - int smem_write_stage_idx = 1; - - // Avoid reading out of bounds - iterator_A.clear_mask(gemm_k_iterations <= 1); - iterator_B.clear_mask(gemm_k_iterations <= 1); - - // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). - - // - // Mainloop - // - - // Note: The main loop does not support Base::kWarpGemmIterations == 2. - CUTLASS_GEMM_LOOP - for (; gemm_k_iterations > 0; --gemm_k_iterations) - { - // - // Loop over GEMM K dimension - // - - CUTLASS_PRAGMA_UNROLL - for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) - { - - // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group - // as the case may be. - - if (warp_mma_k == Base::kWarpGemmIterations - 1) - { - - // Write fragments to shared memory - this->smem_iterator_A_.store(transformA(tb_frag_A)); - - this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); - - __syncthreads(); - - ++this->smem_iterator_A_; - ++this->smem_iterator_B_; - - // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory - if (smem_write_stage_idx == 1) - { - this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); - this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); - } - else - { - this->warp_tile_iterator_A_.add_tile_offset( - {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); - this->warp_tile_iterator_B_.add_tile_offset( - {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); - } - - smem_write_stage_idx ^= 1; - } - - this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); - this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); - ++this->warp_tile_iterator_A_; - - int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; - int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; - // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
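// The register-level double buffering that the comment above refers to, reduced to
// its skeleton: while the warp computes on fragment[k % 2], the data for the next
// iteration is loaded into fragment[(k + 1) % 2]. Loader and Mma are hypothetical
// callables standing in for the shared-memory tile iterator and the warp-level MMA.
#include <array>

template <typename Fragment, typename Loader, typename Mma>
void pipelined_k_loop(int k_iterations, Loader load_tile, Mma mma)
{
    std::array<Fragment, 2> frag;
    load_tile(frag[0], /*k_group=*/0);  // prologue: the first fragment is loaded up front
    for (int k = 0; k < k_iterations; ++k)
    {
        // Prefetch the next fragment (wrapping at the end) before computing on the
        // current one, so the load latency overlaps with the math.
        load_tile(frag[(k + 1) % 2], (k + 1) % k_iterations);
        mma(frag[k % 2]);
    }
}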
- if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) - { - this->warp_tile_iterator_B_.set_kgroup_index( - (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); - this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); - ++this->warp_tile_iterator_B_; - } - - if (warp_mma_k == 0) - { - - iterator_A.load(tb_frag_A); - iterator_B.load(tb_frag_B); - - ++iterator_A; - ++iterator_B; - - // Avoid reading out of bounds if this was the last loop iteration - iterator_A.clear_mask(gemm_k_iterations <= 2); - iterator_B.clear_mask(gemm_k_iterations <= 2); - } - - typename TransformBAfterLDS::result_type converted_frag_B - = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); - warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); - run_warp_mma( - warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); - } - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h deleted file mode 100644 index 350b247de2e..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h +++ /dev/null @@ -1,107 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! 
\file - \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. -*/ - -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/warp/default_mma_tensor_op.h" -#include "cutlass/gemm/warp/mma_tensor_op.h" - -#include "cutlass_extensions/arch/mma.h" -#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" - -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for m-by-n-by-kgroup -template < - /// Shape of one matrix production operation (concept: GemmShape) - typename WarpShape_, - /// Shape of one matrix production operation (concept: GemmShape) - typename InstructionShape_, - /// Data type of A elements, - typename ElementA, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA, - /// Data type of B elements - typename ElementB, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB, - /// Element type of C matrix - typename ElementC, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC, - /// Number of partitions along K dimension - int PartitionsK, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. - bool AccumulatorsInRowMajor> -struct DefaultMmaTensorOp -{ - -private: - // Shape for computing the FP16s - using ComputeInstructionShape = InstructionShape_; - - // Chosen so we get K=16 for int8 and K=32 for int4. - static constexpr int LoadInstructionK = 128 / sizeof_bits::value; - - // Shape for loading the narrow data type from shared memory - using LoadInstructionShape = GemmShape; - -public: - using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< - cutlass::arch::Mma, - cutlass::MatrixShape<1, 1>>; - - // Define the warp-level tensor op - using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h deleted file mode 100644 index 7c5088894b4..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h +++ /dev/null @@ -1,306 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. 
Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates implementing warp-level matrix multiply-accumulate operations targeting - Tensor Cores. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/cutlass.h" -#include "cutlass/platform/platform.h" - -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_conversion.h" -#include "cutlass/numeric_types.h" - -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/arch/mma_sm75.h" -#include "cutlass/arch/mma_sm80.h" -#include "cutlass/arch/mma_sm89.h" - -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/warp/mma.h" - -#include "cutlass/gemm/warp/mma_tensor_op_policy.h" - -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" -#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ - -///////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. -template < - /// Size of the Gemm problem - concept: gemm::GemmShape<> - typename Shape_, - /// Data type of A elements - typename ElementA_, - /// Layout of A matrix (concept: MatrixLayout) - typename LayoutA_, - /// Data type of B elements - typename ElementB_, - /// Layout of B matrix (concept: MatrixLayout) - typename LayoutB_, - /// Element type of C matrix - typename ElementC_, - /// Layout of C matrix (concept: MatrixLayout) - typename LayoutC_, - /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) - typename Policy_, - /// Instruction shape to override shared memory iterators with - typename SharedMemoryInstructionShape_, - /// Number of partitions along K dimension - int PartitionsK_ = 1, - /// Store the accumulators in row major or column major. Row major is used - /// when output layout is interleaved. 
- bool AccumulatorsInRowMajor = false, - /// Used for partial specialization - typename Enable = bool> -class MmaTensorOpComputeBWithF16 -{ -public: - /// Shape of warp-level matrix operation (concept: GemmShape) - using Shape = Shape_; - - /// Data type of multiplicand A - using ElementA = ElementA_; - - /// Layout of multiplicand A - using LayoutA = LayoutA_; - - /// Data type of multiplicand B - using ElementB = ElementB_; - - /// Layout of multiplicand B - using LayoutB = LayoutB_; - - /// Data type of accumulator matrix C - using ElementC = ElementC_; - - /// Layout of accumulator matrix C - using LayoutC = LayoutC_; - - /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) - using Policy = Policy_; - - /// Underlying matrix multiply operator (concept: arch::Mma) - using ArchMmaOperator = typename Policy::Operator; - - /// Indicates math operator - using MathOperator = typename ArchMmaOperator::Operator; - - /// Architecture tag from underlying instruction - using ArchTag = typename ArchMmaOperator::ArchTag; - static_assert((platform::is_same::value - && platform::is_same::value) - || (platform::is_same::value - && platform::is_same::value - && ArchTag::kMinComputeCapability >= 80) - || (platform::is_same::value - && platform::is_same::value - && ArchTag::kMinComputeCapability >= 89), - "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); - - static_assert(platform::is_same::value - || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80) - || (platform::is_same::value && ArchTag::kMinComputeCapability >= 89), - "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada"); - - /// Indicates class of matrix operator - using OperatorClass = arch::OpClassTensorOp; - - /// Shape of underlying instruction - using InstructionShape = typename ArchMmaOperator::Shape; - - /// Instruction shape to override shared memory iterators with - using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; - - static_assert( - SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load"); - static_assert( - SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); - - static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; - - static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); - - /// Complex transform on A operand - static ComplexTransform const kTransformA = ComplexTransform::kNone; - - /// Complex transform on B operand - static ComplexTransform const kTransformB = ComplexTransform::kNone; - - /// Number of threads participating in warp-level matrix product - static int const kThreadCount = 32; - - /// Number of partitions along K dimension - static int const kPartitionsK = PartitionsK_; - -public: - /// Iterates over the A operand in memory - using IteratorA - = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, - MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; - - /// Storage for A tile - using FragmentA = typename IteratorA::Fragment; - - /// Storage for transformed A tile - using TransformedFragmentA = Array; - - /// Iterates over the B operand in memory - using IteratorB = MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, - LayoutB, MatrixShape, Policy::OpDelta::kRow, - kThreadCount, kPartitionsK>; - - /// Storage for B tile - using FragmentB = typename IteratorB::Fragment; - - /// Storage for transformed B tile - using 
TransformedFragmentB = Array; - - /// Iterates over the C operand in memory - using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, - typename ArchMmaOperator::Shape, typename Policy::OpDelta>; - - /// Storage for C tile - using FragmentC = typename IteratorC::Fragment; - - /// Number of mma operations performed - using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, - (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; - -public: - /// Underlying matrix multiply operator (concept: arch::Mma) - ArchMmaOperator mma; - -public: - // - // Methods - // - - /// Ctor - CUTLASS_DEVICE - MmaTensorOpComputeBWithF16() {} - - /// Performs a warp-level matrix multiply-accumulate operation - CUTLASS_DEVICE - void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, - int const warp_tileB_k_offset) const - { - - using MmaOperandA = typename ArchMmaOperator::FragmentA; - using MmaOperandB = typename ArchMmaOperator::FragmentB; - using MmaOperandC = typename ArchMmaOperator::FragmentC; - - static_assert( - TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, - "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " - "B"); - - D = C; - - MmaOperandA const* ptr_A = reinterpret_cast(&A); - MmaOperandB const* ptr_B = reinterpret_cast(&B); - MmaOperandC* ptr_D = reinterpret_cast(&D); - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) - { - - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) - { - - int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); - - int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; - if (AccumulatorsInRowMajor) - { // matrix B is reordered - mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], - ptr_D[n + m_serpentine * MmaIterations::kColumn]); - } - else - { - mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], - ptr_D[m_serpentine + n * MmaIterations::kRow]); - } - } - } -#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < MmaIterations::kRow; ++m) - { - - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < MmaIterations::kColumn; ++n) - { - - int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); - - int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; - if (AccumulatorsInRowMajor) - { // matrix B is reordered - mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], - ptr_D[n_serpentine + m * MmaIterations::kColumn]); - } - else - { - mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], - ptr_D[m + n_serpentine * MmaIterations::kRow]); - } - } - } -#else - assert(0); -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h deleted file mode 100644 index 1d5cd5d8985..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h +++ /dev/null @@ -1,463 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
-*/ - -#pragma once - -#include "cutlass/cutlass.h" - -#include "cutlass/array.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/numeric_types.h" -#include "cutlass/tensor_ref.h" - -#include "cutlass/arch/arch.h" -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/gemm/gemm.h" - -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/layout/tensor.h" - -#include "cutlass/functional.h" -#include "cutlass/platform/platform.h" - -#include "cutlass_extensions/weight_only_quant_op.h" -#include "tensorrt_llm/common/cudaBf16Wrapper.h" - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace gemm -{ -namespace warp -{ - -//////////////////////////////////////////////////////////////////////////////// - -template < - /// Matrix multiply operator - typename MmaOperator_, - /// Size of the matrix to load (concept: MatrixShape) - typename Shape_, - /// Operand identity - Operand Operand, - /// Data type of Scale elements - typename Element_, - /// Layout of operand - typename Layout_, - /// Number of threads participating in one matrix operation - int Threads, - /// - WeightOnlyQuantOp QuantOp_, - /// - typename Enable = void> -class MmaTensorOpDequantizer; - -//////////////////////////////////////////////////////////////////////////////// -// Bfloat specialization for Ampere -template < - /// Underlying matrix multiply operator (concept: MmaTensorOp) - typename MmaOperator_, - /// Shape of the warp level matrix multiply (concept: GemmShape) - typename Shape_, - /// - WeightOnlyQuantOp QuantOp_> -class MmaTensorOpDequantizer= 80 - && platform::is_same::value>::type> -{ - -public: - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - // This is the ratio of the load instruction vs the compute instruction. 
- static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; - - /// Type of the scales - using ElementScale = bfloat16_t; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; - - // Fragment to hold scale data to apply to B before mma - // We need 1 fp16 per matrix iteration in the N dimension - static constexpr int kColsPerMmaPerThread = 1; - using FragmentScale = Array; - using FragmentZero = Array; - - /// Warp mma shape - using Shape = Shape_; - - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) - { - int const warp_offset = warp_idx_n * Shape::kN; - int const quad = lane_idx / 4; - int const thread_offset = warp_offset + quad; - pointer_scale_ = smem_scales.data() + thread_offset; - if constexpr (hasZero(QuantOp)) - { - pointer_zero_ = smem_zeros.data() + thread_offset; - } - } - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) - : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) - { - } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } - } - - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) - { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); - - __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); - __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); - - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - { - operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); - } - } -#else - // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should - // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid - // numerous conversion instructions in GEMM main loop. 
- arch::device_breakpoint(); -#endif - } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag, FragmentScale& zero_frag) - { - if constexpr (hasZero(QuantOp)) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } - } - } - - CUTLASS_DEVICE - void dequantize( - FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) - { -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); - __nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag); - - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); - - __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); - __nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]); - __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); - - if constexpr (hasZero(QuantOp)) - { - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - { - operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2); - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) - { - operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); - } - } - } -#else - // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should - // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid - // numerous conversion instructions in GEMM main loop. - arch::device_breakpoint(); -#endif - } - - // Adds a pointer offset in units of elements. 
- CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) - { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_scale_ += offset; - pointer_zero_ += offset; - } - -private: - ElementScale const* pointer_scale_; - ElementScale const* pointer_zero_; -}; - -//////////////////////////////////////////////////////////////////////////////// - -// Specialization for Turing & Ampere -template < - /// Underlying matrix multiply operator (concept: MmaTensorOp) - typename MmaOperator_, - /// Shape of the warp level matrix multiply (concept: GemmShape) - typename Shape_, - /// - WeightOnlyQuantOp QuantOp_> -class MmaTensorOpDequantizer= 75 - && platform::is_same::value>::type> -{ - -public: - /// Mma Operator - using MmaOperator = MmaOperator_; - - // The architecture specific mma ooperator being used - using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; - - // Mma Instruction Shape - using InstructionShape = typename ArchMmaOperator::Shape; - - // This is the ratio of the load instruction vs the compute instruction. - static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; - - /// Type of the scales - using ElementScale = half_t; - - /// Fragment to hold B data before Mma - using FragmentDequantizedOperand = Array; - - // Fragment to hold scale data to apply to B before mma - // We need 1 fp16 per matrix iteration in the N dimension - static constexpr int kColsPerMmaPerThread = 1; - using FragmentScale = Array; - using FragmentZero = Array; - - /// Warp mma shape - using Shape = Shape_; - - /// Layout of the scales in shared memory - using Layout = layout::RowMajor; - - /// TensorRef type for loading element from a tensor - using TensorRef = TensorRef; - - static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) - { - int const warp_offset = warp_idx_n * Shape::kN; - int const quad = lane_idx / 4; - int const thread_offset = warp_offset + quad; - pointer_scale_ = smem_scales.data() + thread_offset; - if constexpr (hasZero(QuantOp)) - { - pointer_zero_ = smem_zeros.data() + thread_offset; - } - } - - CUTLASS_DEVICE - MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) - : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) - { - } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } - } - - CUTLASS_DEVICE - void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) - { - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB - = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - multiplies mul_op; - - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); - } - } - - CUTLASS_DEVICE - void load(FragmentScale& scale_frag, FragmentScale& zero_frag) - { - if constexpr (hasZero(QuantOp)) - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < 
MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; - } - } - } - - CUTLASS_DEVICE - void dequantize( - FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) - { - using _MmaOperandB = typename ArchMmaOperator::FragmentB; - using ExpandedMmaOperandB - = Array; - static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn - == FragmentDequantizedOperand::kElements, - ""); - - multiplies mul_op; - ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); - - if constexpr (hasZero(QuantOp)) - { - plus plus_op; - - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - operand_frag_ptr[mma_n_iter] - = plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]); - } - } - else - { - CUTLASS_PRAGMA_UNROLL - for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) - { - operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); - } - } - } - - // Adds a pointer offset in units of elements. - CUTLASS_DEVICE - void add_pointer_offset(int64_t const& offset) - { - static_assert(sizeof(ElementScale) > 1, ""); - pointer_scale_ += offset; - pointer_zero_ += offset; - } - -private: - ElementScale const* pointer_scale_; - ElementScale const* pointer_zero_; -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace warp -} // namespace gemm -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h deleted file mode 100644 index 4acef2d180f..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -namespace tensorrt_llm -{ -namespace cutlass_extensions -{ -// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape -// in the kernel layout details when doing weight only quantization. 
-enum class CutlassTileConfig -{ - // Signals that we should run heuristics do choose a config - Undefined, - - // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, - - // SiMT config - CtaShape128x128x8_WarpShape64x64x8, - - // TensorCore configs CTA_N = 128, CTA_K = 64 - // Warp configs for M=16 - CtaShape16x128x64_WarpShape16x32x64, - // Warp configs for M=32 - CtaShape32x128x64_WarpShape32x32x64, - - // Warp configs for M=64 - CtaShape64x128x64_WarpShape32x64x64, - CtaShape64x64x128_WarpShape32x64x64, - CtaShape64x128x64_WarpShape64x32x64, - - // Warp configs for M=128 - CtaShape128x64x64_WarpShape64x32x64, - CtaShape128x128x64_WarpShape64x32x64, - CtaShape128x128x64_WarpShape64x64x64, - CtaShape128x128x64_WarpShape128x32x64, - CtaShape128x256x64_WarpShape64x64x64, - - // Warp configs for M=256 - CtaShape256x128x64_WarpShape64x64x64, - - // TensorCore config CTA_N = 64, CTA_K = 128 - CtaShape128x64x128_WarpShape64x32x128, - - // TensorCore config CTA_N = 256, CTA_K = 64 - CtaShape16x256x64_WarpShape16x64x64, - - // TensorCore config CTA_N = 256, CTA_K = 128 - CtaShape16x256x128_WarpShape16x64x128 - -}; - -enum class SplitKStyle -{ - NO_SPLIT_K, - SPLIT_K_SERIAL, - STREAM_K, // Sm80+ - // SPLIT_K_PARALLEL // Not supported yet -}; - -enum class CutlassTileConfigSM90 -{ - // Signals that we should run heuristics do choose a config - Undefined, - - // Signals that we should run heuristics do choose a config - ChooseWithHeuristic, - - // CTA configs for M=64 - CtaShape64x16x128B, - CtaShape64x32x128B, - CtaShape64x64x128B, - CtaShape64x128x128B, - CtaShape64x256x128B, - - // CTA configs for M=128 - CtaShape128x16x128B, - CtaShape128x32x128B, - CtaShape128x64x128B, - CtaShape128x128x128B, - CtaShape128x256x128B, - - // CTA configs for M=128 - CtaShape256x128x128B, -}; - -enum class MainloopScheduleType -{ - AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this - // defaults to the "legacy" main loop schedule. -}; - -enum class EpilogueScheduleType -{ - AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For - // architectures older than hopper, the epilogue is always performed by the same thread block as the main loop. 
-}; - -enum class ClusterShape -{ - ClusterShape_1x1x1, - ClusterShape_2x1x1, - ClusterShape_1x2x1, - ClusterShape_2x2x1, - ClusterShape_1x8x1, - ClusterShape_8x1x1 -}; - -struct CutlassGemmConfig -{ - enum CandidateConfigTypeParam : int - { - NONE = 0, - WEIGHT_ONLY = 1u << 0, - SIMT_ONLY = 1u << 1, - INT8_ONLY = 1u << 2, - HOPPER = 1u << 3, - GROUPED_GEMM = 1u << 4, - FP8_ONLY = 1u << 5, - }; - - CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; - SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; - int split_k_factor = -1; - int stages = -1; - - // config options for sm90 - CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic; - MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; - EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; - ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; - bool is_sm90 = false; - - CutlassGemmConfig() {} - - CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) - : tile_config(tile_config) - , split_k_style(split_k_style) - , split_k_factor(split_k_factor) - , stages(stages) - , is_sm90(false) - { - } - - CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule, - EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape) - : tile_config_sm90(tile_config_sm90) - , mainloop_schedule(mainloop_schedule) - , epilogue_schedule(epilogue_schedule) - , cluster_shape(cluster_shape) - , is_sm90(true) - { - } - - std::string toString() const - { - std::stringstream tactic; - tactic << "Cutlass GEMM Tactic"; - if (tile_config_sm90 != tensorrt_llm::cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic) - { - assert(is_sm90 && "Invalid cutlass GEMM config"); - tactic << "\n\tstyle=TMA" - << "\n\ttile shape ID: " << (int) tile_config_sm90 << "\n\tcluster shape ID: " << (int) cluster_shape - << "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule; - } - else if (tile_config != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) - { - assert(!is_sm90 && "Invalid cutlass GEMM config"); - tactic << "\n\tstyle=compatible" - << "\n\ttile shape ID: " << (int) tile_config << "\n\tstages: " << (int) stages - << "\n\tsplit k: " << (int) split_k_factor; - } - else - { - tactic << "\n\tundefined"; - } - tactic << "\n"; - return tactic.str(); - } -}; - -inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) -{ - // clang-format off - if (config.is_sm90) - { - out << "tile_config_sm90_enum: " << int(config.tile_config_sm90) - << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) - << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) - << ", cluster_shape_enum: " << int(config.cluster_shape); - } - else - { - out << "tile_config_enum: " << int(config.tile_config) - << ", split_k_style_enum: " << int(config.split_k_style) - << ", split_k_factor: " << config.split_k_factor - << ", stages: " << config.stages; - } - // clang-format on - return out; -} - -} // namespace cutlass_extensions -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h deleted file mode 100644 index 44ba79680e6..00000000000 --- 
a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +++ /dev/null @@ -1,447 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! - \file - \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register -*/ - -#pragma once - -#include "cutlass/arch/arch.h" -#include "cutlass/array.h" -#include "cutlass/half.h" -#include "cutlass/numeric_types.h" - -namespace cutlass -{ - -// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low -// bits and the odd elemeents are in the high bits of the register. In addition, it assumes elements were originally -// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned. -// This converter will uninterleave the data and subtract the bias while converting to the result type. 
-template -struct FastInterleavedAndBiasedNumericArrayConverter -{ -}; - -template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); - - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); - asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); - - // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 4; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - - uint32_t* bf16_result_ptr = reinterpret_cast(&result); - uint32_t const i8s = reinterpret_cast(source); - - static constexpr uint32_t fp32_base = 0x4B000000; - float fp32_intermediates[4]; - - // Construct FP32s, bfloat does not have enough mantissa for IADD trick - uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); - fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); - - // Subtract out fp32_base + 128 to make the unsigned integer signed. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < 4; ++ii) - { - fp32_intermediates[ii] -= 8388736.f; - } - - // Truncate the fp32 representation and pack up as bfloat16s. 
- CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < 2; ++ii) - { - bf16_result_ptr[ii] - = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632); - } -#else - // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use - // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. - result.clear(); // Suppress compiler warning - arch::device_breakpoint(); -#endif - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 4; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; - - uint32_t* h = reinterpret_cast(&result); - uint32_t const i4s = reinterpret_cast(source); - - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t BOTTOM_MASK = 0x000f000f; - static constexpr uint32_t TOP_MASK = 0x00f000f0; - static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400; - - // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing - // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions. - // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and - // elt_67 to fp16 without having to shift them to the bottom bits before hand. - - // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue - // immediately before required. 
- const uint32_t top_i4s = i4s >> 8; - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[1]) - : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[2]) - : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[3]) - : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); - - // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the - // half2 ctor. In this case, I chose performance reliability over code readability. - - // This is the half2 {1032, 1032} represented as an integer. - static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; - // This is the half2 {1 / 16, 1 / 16} represented as an integer. - static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; - // This is the half2 {-72, -72} represented as an integer. - static constexpr uint32_t NEG_72 = 0xd480d480; - - // Finally, we construct the output numbers. - // Convert elt_01 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_23 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); - // Convert elt_45 - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); - // Convert elt_67 - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 8; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template <> -struct FastInterleavedAndBiasedNumericArrayConverter -{ - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - result_type result; -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) - - uint32_t* h = reinterpret_cast(&result); - uint32_t const source_i4s = reinterpret_cast(source); - - // First, we extract the i4s and construct an intermediate fp16 number. 
- static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; - - // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. - // No shift needed for first item. - uint32_t i4s = source_i4s; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[0]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); - CUTLASS_PRAGMA_UNROLL - for (int ii = 1; ii < result_type::kElements / 2; ++ii) - { - i4s >>= sizeof_bits::value; - // (i4s & 0x000f000f) | 0x43004300 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(h[ii]) - : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); - } - - // This is the BF16 {-136, -136} represented as an integer. - static constexpr uint32_t BF16_BIAS = 0xC308C308; - static constexpr uint32_t BF16_ONE = 0x3F803F80; - - // Finally, we construct the output numbers. - CUTLASS_PRAGMA_UNROLL - for (int ii = 0; ii < result_type::kElements / 2; ++ii) - { - // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction - asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); - } -#else - // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use - // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. - arch::device_breakpoint(); - result.clear(); // Suppress compiler warning. -#endif - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -template -struct FastInterleavedAndBiasedNumericArrayConverter -{ - static constexpr int VEC_WIDTH = 8; - static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); - - using result_type = Array; - using source_type = Array; - - CUTLASS_DEVICE - static result_type convert(source_type const& source) - { - using scalar_result_type = typename result_type::Element; - using scalar_source_type = typename source_type::Element; - FastInterleavedAndBiasedNumericArrayConverter - convert_vector_; - - result_type result; - using vec_result = Array; - using vec_source = Array; - - vec_result* result_ptr = reinterpret_cast(&result); - vec_source const* source_ptr = reinterpret_cast(&source); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < N / VEC_WIDTH; ++i) - { - result_ptr[i] = convert_vector_(source_ptr[i]); - } - - return result; - } - - CUTLASS_DEVICE - result_type operator()(source_type const& s) - { - return convert(s); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h deleted file mode 100644 index 5a0cd295708..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h +++ /dev/null @@ -1,66 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
- * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines new layouts needed for MoE -*/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/pitch_linear_coord.h" - -namespace cutlass -{ -namespace layout -{ - -template -struct ColumnMajorTileInterleave -{ - static constexpr int kRowsPerTile = RowsPerTile; - static constexpr int kColumnsInterleaved = ColumnsInterleaved; -}; - -template -struct IsColumnMajorTileInterleave -{ - static constexpr bool value = false; -}; - -template -struct IsColumnMajorTileInterleave> -{ - static constexpr bool value = true; -}; - -} // namespace layout -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h deleted file mode 100644 index 6095925e372..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h +++ /dev/null @@ -1,250 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. 
Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM - quantization. -*/ - -#pragma once - -#include "cutlass/array.h" -#include "cutlass/coord.h" -#include "cutlass/cutlass.h" -#include "cutlass/layout/matrix.h" -#include "cutlass/layout/pitch_linear.h" -#include "cutlass/matrix_shape.h" -#include "cutlass/predicate_vector.h" -#include "cutlass/tensor_ref.h" -#include "cutlass/tensor_view.h" -#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" - -//////////////////////////////////////////////////////////////////////////////// - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass -{ -namespace transform -{ -namespace threadblock -{ - -//////////////////////////////////////////////////////////////////////////////// - -template -class FineGrainedScaleZeroIterator; - -template -class FineGrainedScaleZeroIterator -{ -public: - using Shape = Shape_; - using Element = Element_; - using Layout = layout::RowMajor; - static int const kAdvanceRank = 0; - static int const kAlignment = Alignment_; - - static int const kAccessesPerVector = 1; - - /// Row index of scales corresponding to the groupsize of 64 - int row_groupsize64_; - int group_size_; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - - using TensorRef = TensorRef; - using TensorView = TensorView; - using TensorCoord = typename Layout::TensorCoord; - using Pointer = Element*; - using NonConstPointer = typename platform::remove_const::type*; - - using AccessType = AlignedArray; - - using Fragment = cutlass::Array; - - // For compatibility with existing iterator interface - struct Params - { - LongIndex stride_ = 0; - - /// amount (in byte) to increment pointer from first access of current tile - /// to first access of next tile - LongIndex inc_advance_ = 0; - - // Default ctor - CUTLASS_HOST_DEVICE - Params() {} - - /// Construct the Params object given a pitch-linear tensor's layout - CUTLASS_HOST_DEVICE - Params(Layout const& layout) - : stride_(layout.stride(0)) - { - inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; - } - }; - -private: - /// 
Internal pointer type permits fast address arithmetic - using BytePointer = char*; - -private: - // - // Data members - // - - /// Parameters object with precomputed internal state - Params const params_; - - /// Internal pointer to first access of tile - BytePointer pointer_scale_; - BytePointer pointer_zero_; - - bool is_valid_ = false; - -public: - /// Constructs a TileIterator from its precomputed state, threadblock offset, - /// and thread ID - CUTLASS_DEVICE - FineGrainedScaleZeroIterator( - ///< Precomputed parameters object - Params const& params, - ///< Pointer to start of scale tensor - Pointer pointer_scale, - ///< Pointer to start of zero tensor - Pointer pointer_zero, - ///< Extent of the scale and bias - TensorCoord extent, - ///< ID of each participating thread - int thread_id, - ///< Initial offset of threadblock - TensorCoord const& threadblock_offset, - ///< Group size - int group_size) - : params_(params) - , pointer_scale_(reinterpret_cast(const_cast(pointer_scale))) - , pointer_zero_(reinterpret_cast(const_cast(pointer_zero))) - { - row_groupsize64_ = threadblock_offset.row(); - group_size_ = group_size; - - const LongIndex tb_row_byte_offset - = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8; - const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits::value / 8; - pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset); - - if (pointer_zero_ != nullptr) - { - pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset); - } - - static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; - - int const thread_row = thread_id / THREADS_PER_ROW; - int const thread_col = thread_id % THREADS_PER_ROW; - - const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8; - const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8; - pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset); - if (pointer_zero_ != nullptr) - { - pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset); - } - - // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on - // a given iteration. The same threads will be responsible for issues reads since the number of scales - // read in a given iteration is a constant. Therefore, we should never have to update is_valid_ - // outside of the constructor. 
- int const global_row = threadblock_offset.row() + thread_row; - int const global_col = threadblock_offset.column() + thread_col * kAlignment; - - bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; - bool const col_in_bounds = global_col < extent.column(); - - is_valid_ = row_in_bounds && col_in_bounds; - } - - /// Construct a PredicatedTileAccessIterator with zero threadblock offset - CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object - Pointer pointer_scale, ///< Pointer to start of scale tensor - Pointer pointer_zero, ///< Pointer to start of zero tensor - TensorCoord extent, ///< Extent of tensor - int thread_id, ///< ID of each participating thread - int group_size) - : FineGrainedScaleZeroIterator( - params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size) - { - } - - CUTLASS_DEVICE - void add_tile_offset(TensorCoord const& tile_offset) - { - const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; - const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; - pointer_scale_ += row_byte_offset + col_byte_offset; - if (pointer_zero_ != nullptr) - { - pointer_zero_ += row_byte_offset + col_byte_offset; - } - } - - /// Clears the predicate set efficiently - CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) - { - is_valid_ &= (!enable); - } - - /// Returns whether access is valid or not - CUTLASS_HOST_DEVICE - bool valid() const - { - return is_valid_; - } - - /// Returns a scale pointer - CUTLASS_HOST_DEVICE - AccessType* get_scale() const - { - return reinterpret_cast(pointer_scale_); - } - - /// Returns a zero pointer - CUTLASS_HOST_DEVICE - AccessType* get_zero() const - { - return reinterpret_cast(pointer_zero_); - } -}; - -} // namespace threadblock -} // namespace transform -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp deleted file mode 100644 index b430380b014..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp +++ /dev/null @@ -1,181 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include "cute/layout.hpp" -#include "cute/tensor.hpp" -#include "cute/util/print.hpp" - -using namespace cute; - -/// Function object that applies an index to its argument -template -struct IndexedGather -{ - CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {}) - : indices_(indices) - { - } - - template - CUTE_HOST_DEVICE constexpr auto operator()(I i) const - { - return indices_[i]; - } - - CUTE_HOST_DEVICE friend void print(IndexedGather const& s) - { - cute::print("Indexed{"); - print(s.indices_); - print("}"); - } - - Iter indices_; -}; - -/// Custom stride object that applies a function followed by a stride -template -struct CustomStride -{ - CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, Stride const& stride) - : func_(func) - , stride_(stride) - { - } - - template - CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s) - { - return s.func_(i) * s.stride_; - } - - template - CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i) - { - return s.func_(i) * s.stride_; - } - - CUTE_HOST_DEVICE friend void print(CustomStride const& s) - { - cute::print("Custom{"); - print(s.func_); - cute::print(","); - print(s.stride_); - cute::print("}"); - } - - template - CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) - { - return CustomStride(s.func_, safe_div(s.stride_, div)); - } - - // Circumvent the requirement on make_layout that shape and stride are integral - template - CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride) - { - return Layout(shape, stride); - } - - Func func_; - Stride stride_; -}; - -template -CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func) -{ - // Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather stride - auto idx = find_if(stride, [](auto x) { return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; }); - constexpr int I = decltype(idx)::value; - return make_layout( - repeat_like(stride, _1{}), replace(stride, CustomStride{static_cast(func), get(stride)})); -} - -/// Helper function to optionally create a gather tensor -template -CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func) -{ - Layout matrix_layout = make_identity_layout(shape); - auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); - Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); - return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); -} - -namespace cute -{ - -template -CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride) -{ - if constexpr (is_tuple::value) - { - return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s, d); }); - } - else if 
constexpr (is_scaled_basis::value) - { - if constexpr (Stride::mode() == I) - { - return make_layout(shape_div(shape, Int{}), shape_div(stride, Int{})); - } - else - { - return make_layout(shape, stride); - } - } - else - { - return upcast(shape, stride); - } - - CUTE_GCC_UNREACHABLE; -} - -template -CUTE_HOST_DEVICE constexpr auto upcast( - ComposedLayout, Offset, Layout> const& layout) -{ - // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset - auto idx = find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; }); - constexpr int I = decltype(idx)::value; - - // Upcast the outer layout (works as expected) - auto outer = upcast(layout.layout_a()); - - // Upcast the accumulated offset along stride-1 mode - auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); - - // Upcast the inner layout's shape along stride-1 mode - auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); - - return composition(outer, offset, inner); -} - -} // namespace cute diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h deleted file mode 100644 index 64774428e9f..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h +++ /dev/null @@ -1,58 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! \file - \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
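For context on the gather-tensor helpers removed above: the deleted header pairs an index functor (`IndexedGather`) with a stride wrapper (`CustomStride`) so the leading coordinate is remapped through the functor before the stride multiply, which is how `make_gather_tensor` expresses a row-gather purely through the layout. Below is a minimal, dependency-free C++ analogue of that idea; the struct and variable names are illustrative only, and the real code composes this into CuTe `ComposedLayout`s rather than computing raw offsets by hand.

```cpp
// Simplified sketch of the gather-stride idea: an index functor is applied to
// the leading coordinate before the stride multiply, so a "gathered" row view
// is expressed entirely in the layout. Names here are illustrative only.
#include <cstddef>
#include <cstdio>
#include <vector>

// Analogue of IndexedGather: maps a logical row index to a physical one.
struct IndexedGather {
    const int* indices;
    int operator()(int i) const { return indices[i]; }
};

// Analogue of CustomStride: applies the functor, then the stride.
template <class Func>
struct CustomStride {
    Func func;
    std::ptrdiff_t stride;
    std::ptrdiff_t operator*(int i) const { return func(i) * stride; }
};

int main() {
    // 4x3 row-major matrix, flattened.
    std::vector<float> data = {0, 1, 2,  10, 11, 12,  20, 21, 22,  30, 31, 32};
    const int gather_rows[] = {3, 1};                  // view rows 3 and 1 only
    CustomStride<IndexedGather> row_stride{{gather_rows}, 3};

    // Element (i, j) of the gathered view resolves to data[indices[i]*3 + j].
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 3; ++j)
            std::printf("view(%d,%d) = %g\n", i, j, data[row_stride * i + j]);
    return 0;
}
```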
-*/ - -#pragma once - -namespace cutlass -{ - -enum class WeightOnlyQuantOp -{ - UNDEFINED, - PER_COLUMN_SCALE_ONLY, - FINEGRAINED_SCALE_ONLY, - FINEGRAINED_SCALE_AND_ZEROS -}; - -constexpr bool isFinegrained(WeightOnlyQuantOp op) -{ - return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; -} - -constexpr bool hasZero(WeightOnlyQuantOp op) -{ - return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; -} - -} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h deleted file mode 100644 index f4eed277c18..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.h +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -namespace tensorrt_llm::kernels::cutlass_kernels -{ -template -void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, - ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert, - int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, - int* kernel_occupancy); -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl deleted file mode 100644 index 126e761ec93..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/fused_moe_gemm_launcher_sm80.inl +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
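The `weight_only_quant_op.h` header deleted above is small enough to summarise: it classifies weight-only quantization modes and exposes two predicates, `isFinegrained` and `hasZero`. The sketch below shows one way a caller could use them to size scale and zero-point buffers; the per-group sizing rule is an assumption for the example, not code recovered from the deleted kernels.

```cpp
// Illustrative use of the WeightOnlyQuantOp predicates from the deleted
// header. The sizing rule (per-group scales for finegrained modes, one scale
// per output column otherwise) is an assumption for this example only.
#include <cstddef>

enum class WeightOnlyQuantOp {
    UNDEFINED,
    PER_COLUMN_SCALE_ONLY,
    FINEGRAINED_SCALE_ONLY,
    FINEGRAINED_SCALE_AND_ZEROS
};

constexpr bool isFinegrained(WeightOnlyQuantOp op) {
    return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS
        || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
}

constexpr bool hasZero(WeightOnlyQuantOp op) {
    return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
}

struct QuantBufferSizes {
    size_t num_scales;
    size_t num_zeros;
};

// Hypothetical helper: how many scale / zero-point values a (k, n) weight
// matrix needs for a given quantization mode.
constexpr QuantBufferSizes quantBufferSizes(WeightOnlyQuantOp op, size_t k, size_t n, size_t group_size) {
    size_t scales = isFinegrained(op) ? (k / group_size) * n : n;
    size_t zeros = hasZero(op) ? scales : 0;
    return {scales, zeros};
}

static_assert(quantBufferSizes(WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, 4096, 11008, 128).num_scales == 11008,
    "per-column mode: one scale per output column");
static_assert(quantBufferSizes(WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, 4096, 11008, 128).num_zeros == 32 * 11008,
    "finegrained mode with zeros: one zero per scale group");
```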
- */ - -#include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/gemm/device/gemm_grouped.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" - -#include "cute/tensor.hpp" -#include "cutlass/cutlass.h" - -#include -#include -#include - -namespace tensorrt_llm::kernels::cutlass_kernels -{ -template -void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWeightType_ const* B, - ElementType_ const* biases, bool bias_is_broadcast, ElementType_* C, int64_t const* total_tokens_including_expert, - int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, int multi_processor_count, cudaStream_t stream, - int* kernel_occupancy) -{ - constexpr auto activation_type = fused_moe::EpilogueRouting(true); - using GemmType = fused_moe::Fused_Moe_Kernel_sm80; - - // make sure GPU has enough resources.. - if (kernel_occupancy != nullptr) - { - constexpr int smem_size = GemmType::kSmemSize; - - if (smem_size > (48 << 10)) - { - cudaFuncAttributes attr{}; - int device = 0; - int max_smem_per_block = 0; - tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); - tensorrt_llm::common::check_cuda_error( - cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device)); - tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, fused_moe::run_global)); - if (smem_size + attr.sharedSizeBytes >= static_cast(max_smem_per_block)) - { - // This should mean that - // cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - // smem_size) wouldn't work. In that case, we return an occupancy of 0. This will cause the - // heuristic to ignore this configuration. - *kernel_occupancy = 0; - return; - } - } - - int max_active_blocks = -1; - tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_active_blocks, fused_moe::run_global, GemmType::kThreadCount, smem_size)); - *kernel_occupancy = max_active_blocks; - return; - } - int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks()); - int const threadblock_count = multi_processor_count * occupancy; - TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel"); - using Arguments = typename GemmType::Arguments; - Arguments args{{const_cast(A), const_cast(B), const_cast(biases), - reinterpret_cast(C), total_tokens_including_expert, static_cast(gemm_n), - static_cast(gemm_k), num_experts, bias_is_broadcast}, - num_experts, threadblock_count}; - auto params = GemmType::to_underlying_arguments(args); - if (GemmType::kSmemSize >= (48 << 10)) - { - cudaError_t result = cudaFuncSetAttribute( - fused_moe::run_global, cudaFuncAttributeMaxDynamicSharedMemorySize, GemmType::kSmemSize); - TLLM_CHECK_WITH_INFO(result == cudaSuccess, - "Fail to set the max smem size to " + std::to_string(GemmType::kSmemSize) + " for fused moe kernel"); - } - dim3 grid(params.threadblock_count, 1, 1); - dim3 block(GemmType::kThreadCount); - fused_moe::run_global<<>>(params); - auto result = cudaGetLastError(); - TLLM_CHECK_WITH_INFO(result == cudaSuccess, "Fail to execute fused moe kernel, cuda error %d\n", (int) (result)); -} -} // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h deleted file mode 100644 index 91527fadb67..00000000000 --- 
a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h +++ /dev/null @@ -1,37 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" -#include - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ - -// Keep in sync with the signature generated by generate_kernels.py -template -void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts, - int multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size); - -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl deleted file mode 100644 index cca60a9816f..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.inl +++ /dev/null @@ -1,348 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
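Both the SM80 launcher deleted above and the SM90 launcher that follows double as occupancy queries: when `kernel_occupancy` is non-null they report how many blocks fit per SM (returning 0 when the required dynamic shared memory cannot fit even with the opt-in limit) instead of launching. Below is a stand-alone CUDA C++ sketch of that pattern using only standard runtime APIs; `dummy_kernel` and the sizes are placeholders, not code from the deleted files.

```cpp
// Stand-alone sketch of the "query occupancy or launch" pattern used by the
// deleted launchers. dummy_kernel and the shared-memory handling are
// simplified placeholders; only the CUDA runtime calls mirror the removed code.
#include <cuda_runtime.h>

__global__ void dummy_kernel() { extern __shared__ char smem[]; (void)smem; }

// Returns blocks/SM if occupancy_only, otherwise launches the kernel once.
int queryOrLaunch(bool occupancy_only, int threads_per_block, size_t smem_size, cudaStream_t stream) {
    if (smem_size > (48u << 10)) {
        // Kernels needing >48 KiB of dynamic shared memory must opt in, and
        // may still not fit on the device; report occupancy 0 in that case.
        int device = 0, max_smem_optin = 0;
        cudaGetDevice(&device);
        cudaDeviceGetAttribute(&max_smem_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
        if (smem_size > static_cast<size_t>(max_smem_optin)) return 0;
        cudaFuncSetAttribute(dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem_size);
    }
    if (occupancy_only) {
        int max_active_blocks = 0;
        cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, dummy_kernel, threads_per_block, smem_size);
        return max_active_blocks;
    }
    dummy_kernel<<<1, threads_per_block, smem_size, stream>>>();
    return cudaGetLastError() == cudaSuccess ? 1 : -1;
}
```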
- */ -#pragma once -#include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/gemm/device/gemm_grouped.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" - -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/tensor_ref.h" - -#include "cutlass_extensions/compute_occupancy.h" -#include "cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp" -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" -#include "cutlass_extensions/gemm/threadblock/default_mma.h" - -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" - -#include -#include -#include -#include - -namespace tensorrt_llm -{ -namespace kernels -{ -namespace cutlass_kernels -{ -using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion; - -// Hopper helper class for defining all the cutlass helper types -template -struct HopperGroupedGemmInfo -{ - using Arch = cutlass::arch::Sm90; - - // TODO Update once mixed input support is added - static_assert(cutlass::platform::is_same::value, - "CUTLASS does not currently have specialised SM90 support for quantized operations"); - -#ifdef ENABLE_FP8 - constexpr static bool IsFP8 - = cutlass::platform::is_same::value || cutlass::platform::is_same::value; -#else - constexpr static bool IsFP8 = false; -#endif - -#ifdef ENABLE_BF16 - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value - || cutlass::platform::is_same::value || IsFP8, - "Specialized for bfloat16, half, float, fp8"); -#else - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || IsFP8, - "Specialized for half, float, fp8"); -#endif - - static_assert(cutlass::platform::is_same::value - || cutlass::platform::is_same::value - || cutlass::platform::is_same::value - || cutlass::platform::is_same::value - || cutlass::platform::is_same::value, - "Unexpected quantization type"); - - // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. 
- using ElementType = typename TllmToCutlassTypeAdapter::type; - - using CutlassWeightTypeMaybeUint4 = typename TllmToCutlassTypeAdapter::type; - // For legacy reasons we convert unsigned 8-bit to signed - using CutlassWeightTypeMaybeUint8 - = std::conditional_t, cutlass::int4b_t, - CutlassWeightTypeMaybeUint4>; - using CutlassWeightType - = std::conditional_t, int8_t, CutlassWeightTypeMaybeUint8>; - - using ElementA = ElementType; - using ElementB = CutlassWeightType; - - using ElementD = typename TllmToCutlassTypeAdapter>::type; - using ElementFinalOutput = typename TllmToCutlassTypeAdapter::type; - - // using ElementC = std::conditional_t; - // using ElementCNoVoid = std::conditional_t; - using ElementC = void; - using ElementCNoVoid = ElementD; - - using ElementAccumulator = float; - - using ElementBias = ElementFinalOutput; - using ElementRouterScales = float; - - // A matrix configuration - this is transposed and swapped with B - using LayoutA = HopperGroupedGemmInput::LayoutA; - constexpr static int AlignmentA - = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units - // of elements (up to 16 bytes) - - // B matrix configuration - this is transposed and swapped with A - using LayoutB = HopperGroupedGemmInput::LayoutB; // Layout type for B matrix operand - constexpr static int AlignmentB - = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units - // of elements (up to 16 bytes) - - // C matrix configuration - using LayoutC = HopperGroupedGemmInput::LayoutC; // Layout type for C matrix operand - using StrideC = HopperGroupedGemmInput::StrideC; - // Note we use ElementType here deliberately, so we don't break when BIAS is disabled - constexpr static int AlignmentC - = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units - // of elements (up to 16 bytes) - - // D matrix configuration - using LayoutD = HopperGroupedGemmInput::DefaultEpilogue::LayoutD; - using StrideD = HopperGroupedGemmInput::DefaultEpilogue::StrideD; - constexpr static int AlignmentD - = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of D matrix - // in units of elements (up to 16 bytes) - - static_assert(cutlass::platform::is_same::value, - "Hopper Grouped GEMM specialisation doesn't support fused activation"); - - using EpilogueOp - = cutlass::epilogue::fusion::LinearCombination; - - // TODO Add mode for fused activation once CUTLASS adds support - // using EpilogueSchedule = cutlass::platform::conditional_t< - // cutlass::platform::is_same::value, - // cutlass::epilogue::PtrArrayNoSmemWarpSpecialized, - // cutlass::epilogue::?????????????????? 
/// <<<<<< what supports activations - // >; - using EpilogueSchedule = cutlass::epilogue::PtrArrayNoSmemWarpSpecialized; - - // Epilogue For Default Finalize - using CollectiveEpilogueDefault = typename cutlass::epilogue::collective::CollectiveBuilder< // - Arch, cutlass::arch::OpClassTensorOp, // - TileShape, ClusterShape, // - cutlass::epilogue::collective::EpilogueTileAuto, // - ElementAccumulator, ElementAccumulator, // - ElementC, LayoutC*, AlignmentC, // - ElementD, LayoutD*, AlignmentD, // - EpilogueSchedule>::CollectiveOp; - - // Epilogue For Fused Finalize - using CollectiveEpilogueFinalize = typename cutlass::epilogue::collective::EpilogueMoeFusedFinalizeBuilder< // - TileShape, // - ElementCNoVoid, StrideC*, // - ElementFinalOutput, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideFinalOutput, // - ElementAccumulator, // - ElementAccumulator, // - ElementBias, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideBias, // - ElementRouterScales, HopperGroupedGemmInput::FusedFinalizeEpilogue::StrideRouterScales // - >::CollectiveOp; - - using CollectiveEpilogue - = std::conditional_t; - - using StageCountAutoCarveout = cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>; - - using KernelSchedule - = std::conditional_t; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< // - Arch, cutlass::arch::OpClassTensorOp, // - CutlassWeightType, LayoutB*, AlignmentB, // A & B swapped here - ElementType, LayoutA*, AlignmentA, // - ElementAccumulator, // - TileShape, ClusterShape, // - StageCountAutoCarveout, KernelSchedule>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal; - - using GemmGrouped = cutlass::gemm::device::GemmUniversalAdapter; -}; - -// Hopper specialised version -template -void sm90_generic_moe_gemm_kernelLauncher(HopperGroupedGemmInput hopper_input, int num_experts, - int const multi_processor_count, cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size) -{ -#ifdef COMPILE_HOPPER_TMA_GEMMS - using namespace cute; - if constexpr (!should_filter_sm90_gemm_problem_shape_v) - { - using GemmInfo - = HopperGroupedGemmInfo; - - using ElementAccumulator = typename GemmInfo::ElementAccumulator; - using ElementA = typename GemmInfo::ElementA; - using ElementB = typename GemmInfo::ElementB; - using ElementC = typename GemmInfo::ElementC; - using ElementCNoVoid = typename GemmInfo::ElementCNoVoid; - using ElementD = typename GemmInfo::ElementD; - - using CollectiveMainloop = typename GemmInfo::CollectiveMainloop; - using CollectiveEpilogue = typename GemmInfo::CollectiveEpilogue; - using GemmKernel = typename GemmInfo::GemmKernel; - using GemmGrouped = typename GemmInfo::GemmGrouped; - - if (kernel_occupancy != nullptr) - { - *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); - return; - } - - cutlass::KernelHardwareInfo hw_info; - hw_info.device_id = 0; - hw_info.sm_count = multi_processor_count; - - GemmGrouped gemm; - - if (workspace_size != nullptr) - { - // Make a mock problem shape with just the minimal information actually required to get the workspace size - // This makes some assumptions about CUTLASS's implementation which is suboptimal. We have a check later to - // catch future cutlass updates causing silent breakages, but that is not fool proof. 
- // The alternative is to wait until we have data and then dynamically allocate the workspace - typename HopperGroupedGemmInput::ProblemShape shape_info{num_experts, nullptr, nullptr}; - - typename GemmGrouped::Arguments args{ - cutlass::gemm::GemmUniversalMode::kGrouped, shape_info, {}, {}, hw_info}; - *workspace_size = gemm.get_workspace_size(args); - return; - } - - using MainloopArguments = typename CollectiveMainloop::Arguments; - TLLM_CHECK(hopper_input.stride_a); - TLLM_CHECK(hopper_input.stride_b); - TLLM_CHECK(hopper_input.ptr_a); - TLLM_CHECK(hopper_input.ptr_b); - - MainloopArguments const mainloop_params = {reinterpret_cast(hopper_input.ptr_b), - hopper_input.stride_b, reinterpret_cast(hopper_input.ptr_a), hopper_input.stride_a}; - - typename GemmGrouped::EpilogueOutputOp::Params epilogue_scalars{ - ElementAccumulator(1.f), hopper_input.ptr_c ? ElementAccumulator(1.f) : ElementAccumulator(0.f)}; - epilogue_scalars.alpha_ptr_array = hopper_input.alpha_scale_ptr_array; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - // TODO(dastokes) ptr_c casts to ElementCNoVoid** because there is a workaround in CUTLASS - auto make_epi_args = [&]() - { - if constexpr (FUSION == EpilogueFusion::NONE) - { - auto epi_params = hopper_input.default_epilogue; - return EpilogueArguments{epilogue_scalars, reinterpret_cast(hopper_input.ptr_c), - hopper_input.stride_c, reinterpret_cast(epi_params.ptr_d), epi_params.stride_d}; - } - else if constexpr (FUSION == EpilogueFusion::FINALIZE) - { - // Parameters for fused finalize - auto epi_params = hopper_input.fused_finalize_epilogue; - return EpilogueArguments{ - epilogue_scalars, // Parameters to underlying epilogue - reinterpret_cast(hopper_input.ptr_c), hopper_input.stride_c, // C params - reinterpret_cast(epi_params.ptr_final_output), - epi_params.stride_final_output, // D (output) params - reinterpret_cast(epi_params.ptr_bias), - epi_params.stride_bias, // Bias params - epi_params.ptr_router_scales, epi_params.stride_router_scales, // Router scales - epi_params.ptr_expert_first_token_offset, // Offset of this expert's token in the router scales - epi_params.ptr_source_token_index, // Index of the source token to sum into - epi_params.num_rows_in_final_output // Number of tokens in the output buffer - }; - } - else - { - static_assert( - sizeof(EpilogueArguments) == 0, "Unimplemented fusion provided to SM90+ MoE gemm launcher"); - } - }; - EpilogueArguments const epilogue_params = make_epi_args(); - - typename GemmKernel::TileScheduler::Arguments scheduler_args{ - 1, GemmKernel::TileScheduler::RasterOrderOptions::AlongN}; - - typename GemmGrouped::Arguments args{cutlass::gemm::GemmUniversalMode::kGrouped, hopper_input.shape_info, - mainloop_params, epilogue_params, hw_info, scheduler_args}; - - size_t calculated_ws_size = gemm.get_workspace_size(args); - TLLM_CHECK_WITH_INFO(calculated_ws_size <= hopper_input.gemm_workspace_size, - "Workspace is size %zu but only %zu were allocated", calculated_ws_size, hopper_input.gemm_workspace_size); - - auto can_implement = gemm.can_implement(args); - TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, - "Grouped GEMM kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement))); - - auto init_status = gemm.initialize(args, hopper_input.gemm_workspace); - TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, - "Failed to initialize cutlass SM90 grouped gemm. 
Error: " - + std::string(cutlassGetStatusString(init_status))); - - auto run_status = gemm.run(stream); - TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, - "Failed to run cutlass SM90 grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status))); - sync_check_cuda_error(); - } - else - { - TLLM_THROW("Configuration was disabled by FAST_BUILD"); - } - -#else // COMPILE_HOPPER_TMA_GEMMS - TLLM_THROW("Please recompile with support for hopper by passing 90-real as an arch to build_wheel.py."); -#endif // COMPILE_HOPPER_TMA_GEMMS -} - -} // namespace cutlass_kernels -} // namespace kernels -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu deleted file mode 100644 index 9862460dd6a..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu +++ /dev/null @@ -1,131 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" -#include "cutlass/conv/convolution.h" -// Order matters here, packed_stride.hpp is missing cute and convolution includes -#include "cutlass/util/packed_stride.hpp" - -#include "tensorrt_llm/common/logger.h" - -namespace tensorrt_llm -{ -std::array HopperGroupedGemmInput::workspaceBuffers(int num_experts) -{ - size_t problem_shape_size = sizeof(ProblemShape::UnderlyingProblemShape) * num_experts; - size_t stride_a_size = sizeof(StrideA) * num_experts; - size_t stride_b_size = sizeof(StrideB) * num_experts; - size_t stride_c_size = sizeof(StrideC) * num_experts; - size_t stride_d_size = sizeof(DefaultEpilogue::StrideD) * num_experts; - - size_t ptr_buf_size = sizeof(void*) * num_experts; - size_t scale_buf_size = sizeof(float*) * num_experts; - - return std::array{problem_shape_size, stride_a_size, stride_b_size, stride_c_size, stride_d_size, ptr_buf_size, - ptr_buf_size, ptr_buf_size, ptr_buf_size, scale_buf_size}; -} - -size_t HopperGroupedGemmInput::workspaceSize(int num_experts) -{ - auto buffers = workspaceBuffers(num_experts); - return tensorrt_llm::common::calculateTotalWorkspaceSize(buffers.data(), buffers.size()); -} - -void HopperGroupedGemmInput::configureWorkspace( - int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size) -{ - auto buffers = workspaceBuffers(num_experts); - std::array pointers{}; - TLLM_CHECK_WITH_INFO(pointers.size() == buffers.size(), "Mismatching workspace size and number of buffers"); - for (int i = 0; i < buffers.size(); i++) - { - pointers[i] = start_ptr; - start_ptr = tensorrt_llm::common::nextWorkspacePtr(start_ptr, buffers[i]); - } - - shape_info.num_groups = 
num_experts; - shape_info.problem_shapes = reinterpret_cast(pointers[0]); - shape_info.host_problem_shapes = nullptr; - stride_a = reinterpret_cast(pointers[1]); - stride_b = reinterpret_cast(pointers[2]); - stride_c = reinterpret_cast(pointers[3]); - default_epilogue.stride_d = reinterpret_cast(pointers[4]); - - ptr_a = reinterpret_cast(pointers[5]); - ptr_b = reinterpret_cast(pointers[6]); - ptr_c = reinterpret_cast(pointers[7]); - default_epilogue.ptr_d = reinterpret_cast(pointers[8]); - - alpha_scale_ptr_array = reinterpret_cast(pointers[9]); - - this->gemm_workspace = reinterpret_cast(gemm_workspace); - this->gemm_workspace_size = gemm_workspace_size; -} - -void HopperGroupedGemmInput::setFinalizeFusionParams(void* final_output, float const* router_scales, - int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, - int num_output_tokens) -{ - fused_finalize_epilogue.ptr_final_output = final_output; - fused_finalize_epilogue.ptr_router_scales = router_scales; - fused_finalize_epilogue.ptr_bias = bias; - fused_finalize_epilogue.ptr_expert_first_token_offset = expert_first_token_offset; - fused_finalize_epilogue.ptr_source_token_index = source_token_index; - - fused_finalize_epilogue.stride_final_output - = cutlass::make_cute_packed_stride(FusedFinalizeEpilogue::StrideFinalOutput{}, - transpose_stride(cute::make_shape(num_output_tokens, hidden_size, 1))); - fused_finalize_epilogue.stride_bias - = transpose_stride(cute::make_stride(cute::Int<0>{}, cute::Int<1>{}, hidden_size)); - fused_finalize_epilogue.stride_router_scales = {}; - - fused_finalize_epilogue.num_rows_in_final_output = num_output_tokens; -} - -std::string HopperGroupedGemmInput::toString() const -{ - std::stringstream ss; - ss << "Hopper Input Information: " << (isValid() ? "valid" : "null") << "\n"; - if (isValid()) - { - ss << "Ptr A: " << ptr_a << ", Ptr B: " << ptr_b << ", Ptr C: " << ptr_c << "\n"; - ss << "Epilogue Fusion: " << (int) fusion; - if (fusion == HopperGroupedGemmInput::EpilogueFusion::FINALIZE) - { - ss << ",\nFinal Output: " << fused_finalize_epilogue.ptr_final_output; - ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; - ss << ",\nBias: " << fused_finalize_epilogue.ptr_bias; - ss << " with Stride: " << fused_finalize_epilogue.stride_bias; - ss << ",\nRouter Scales: " << fused_finalize_epilogue.ptr_router_scales; - ss << " with Stride: " << fused_finalize_epilogue.stride_router_scales; - ss << ",\nExpert Offset: " << fused_finalize_epilogue.ptr_expert_first_token_offset; - ss << ", Source Map: " << fused_finalize_epilogue.ptr_source_token_index; - } - else - { - ss << ", Ptr D: " << default_epilogue.ptr_d; - } - ss << '\n'; - ss << "Alpha scale ptr: " << alpha_scale_ptr_array << "\n"; - } - return ss.str(); -} -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h deleted file mode 100644 index 0616c063654..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h +++ /dev/null @@ -1,230 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
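The `moe_gemm_hopper_input.cu` file deleted above sizes and carves a single workspace allocation: `workspaceBuffers()` lists per-buffer byte counts, `workspaceSize()` sums them, and `configureWorkspace()` walks the allocation handing out sub-pointers via `nextWorkspacePtr`. A simplified sketch of that carving scheme follows; the 128-byte alignment is an assumption for illustration, not a value taken from the deleted code.

```cpp
// Minimal sketch of the workspace-carving scheme: per-buffer sizes are
// computed once, summed to size one allocation, and sub-pointers are then
// peeled off sequentially (analogue of nextWorkspacePtr).
#include <array>
#include <cstddef>
#include <cstdint>

constexpr size_t kAlign = 128;  // assumed alignment for the example

constexpr size_t alignUp(size_t x) { return (x + kAlign - 1) / kAlign * kAlign; }

template <size_t N>
size_t totalWorkspaceSize(const std::array<size_t, N>& buffers) {
    size_t total = 0;
    for (size_t b : buffers) total += alignUp(b);
    return total;
}

// Carve `arena` (at least totalWorkspaceSize(buffers) bytes) into N sub-buffers.
template <size_t N>
std::array<int8_t*, N> carveWorkspace(int8_t* arena, const std::array<size_t, N>& buffers) {
    std::array<int8_t*, N> pointers{};
    for (size_t i = 0; i < N; ++i) {
        pointers[i] = arena;
        arena += alignUp(buffers[i]);
    }
    return pointers;
}
```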
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include "tensorrt_llm/common/cudaFp8Utils.h" -#include "tensorrt_llm/common/workspace.h" -#include "tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h" -#include -#include -#include -#include - -#include "cute/tensor.hpp" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/layout/layout.h" - -namespace tensorrt_llm -{ -template -constexpr auto transpose_stride(T const& t) -{ - return cute::prepend(cute::prepend(cute::take<2, cute::rank_v>(t), cute::get<0>(t)), cute::get<1>(t)); -} - -struct HopperGroupedGemmInput -{ - template - using TransposeStride = decltype(transpose_stride(T{})); - template - using TransposeLayoutTag = std::conditional_t, - cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>; - - static_assert(std::is_same_v>); - static_assert(std::is_same_v>); - - // Layout for A and B is transposed and then swapped in the implementation - // This uses B^T * A^T = (A * B)^T to get a better layout for the GEMM - using LayoutA = TransposeLayoutTag; // Layout type for A matrix operand - using LayoutB = TransposeLayoutTag; // Layout type for B matrix operand - using LayoutC = TransposeLayoutTag; // Layout type for C matrix operand - - using StrideA - = std::remove_pointer_t>; // Use B because they will be swapped - using StrideB - = std::remove_pointer_t>; // Use A because they will be swapped - using StrideC = std::remove_pointer_t>; - - template - constexpr static bool IsFP8_v = std::is_same_v || std::is_same_v; - - // Currently this should always just be T - template - using OutputTypeAdaptor_t = std::conditional_t, nv_bfloat16, T>; - - using ProblemShape = cutlass::gemm::GroupProblemShape>; - - ProblemShape shape_info{}; - StrideA* stride_a = nullptr; - StrideB* stride_b = nullptr; - - void const** ptr_a = nullptr; - void const** ptr_b = nullptr; - - // C is currently the same in both epilogues - StrideC* stride_c = nullptr; - void const** ptr_c = nullptr; - - struct DefaultEpilogue - { - using LayoutD = TransposeLayoutTag; // Layout type for D matrix operand - using StrideD = std::remove_pointer_t>; - - StrideD* stride_d = nullptr; - void** ptr_d = nullptr; - }; - - struct FusedFinalizeEpilogue - { - using StrideFinalOutput = DefaultEpilogue::StrideD; - using StrideBias = TransposeStride>; - using StrideRouterScales = TransposeStride>; - - void* ptr_final_output = nullptr; - StrideFinalOutput stride_final_output{}; - - void const* ptr_bias = nullptr; - StrideBias stride_bias{}; - - float const* ptr_router_scales = nullptr; - StrideRouterScales stride_router_scales{}; - - int64_t const* ptr_expert_first_token_offset = nullptr; - int const* ptr_source_token_index = nullptr; - - size_t num_rows_in_final_output = 0; - }; - - DefaultEpilogue default_epilogue; - FusedFinalizeEpilogue fused_finalize_epilogue; - - enum class EpilogueFusion - { - NONE, - ACTIVATION, - GATED_ACTIVATION, - FINALIZE - }; - EpilogueFusion fusion = EpilogueFusion::NONE; - - float const** alpha_scale_ptr_array = nullptr; - - uint8_t* gemm_workspace = nullptr; - size_t gemm_workspace_size = 0; - 
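The comment in the struct above notes that A and B are transposed and swapped because B^T * A^T = (A * B)^T yields a friendlier output layout for the GEMM. A tiny reference check of that identity, independent of CUTLASS:

```cpp
// Reference check of the operand-swap identity the deleted
// HopperGroupedGemmInput relies on: (A * B)^T == B^T * A^T.
#include <array>
#include <cassert>

using Mat2 = std::array<std::array<double, 2>, 2>;

Mat2 matmul(const Mat2& X, const Mat2& Y) {
    Mat2 Z{};
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
            for (int k = 0; k < 2; ++k) Z[i][j] += X[i][k] * Y[k][j];
    return Z;
}

Mat2 transpose(const Mat2& X) {
    Mat2 T{};
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j) T[i][j] = X[j][i];
    return T;
}

int main() {
    Mat2 A{{{1, 2}, {3, 4}}}, B{{{5, 6}, {7, 8}}};
    assert(transpose(matmul(A, B)) == matmul(transpose(B), transpose(A)));
    return 0;
}
```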
- static std::array workspaceBuffers(int num_experts); - - static size_t workspaceSize(int num_experts); - - void configureWorkspace(int8_t* start_ptr, int num_experts, void* gemm_workspace, size_t gemm_workspace_size); - - bool isValid() const - { - return stride_a != nullptr && ptr_a != nullptr; - } - - void setFinalizeFusionParams(void* final_output, float const* router_scales, - int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size, - int num_output_tokens); - - std::string toString() const; -}; - -// Note update moe.py to match -enum class ActivationType -{ - Gelu = 0, - Relu, - Silu, - Swiglu, - Geglu, - Identity, - InvalidType -}; - -constexpr bool isGatedActivation(ActivationType activation_type) -{ - return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu; -} - -template -class MoeGemmRunner -{ -public: - MoeGemmRunner(); - -#if defined(ENABLE_FP8) - static constexpr bool use_fp8 = std::is_same_v || std::is_same_v; -#else - static constexpr bool use_fp8 = false; -#endif - - void moeGemmBiasAct(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, - ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, - HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - ActivationType activation_type, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, - cutlass_extensions::CutlassGemmConfig chosen_conf); - - void moeGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, void* C, - int64_t const* total_tokens_including_expert, HopperGroupedGemmInput layout_info, int64_t total_rows, - int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array, - cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf); - - std::vector getConfigs() const; - static std::vector getConfigs(int sm); - static std::vector getHopperConfigs(int sm); - static std::vector getAmpereConfigs(int sm); - - [[nodiscard]] bool isHopperSpecialised(cutlass_extensions::CutlassGemmConfig gemm_config) const; - [[nodiscard]] bool supportsHopperSpecialisation() const; - [[nodiscard]] bool isFusedGatedActivation( - cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const; - [[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const; - - size_t getMaxWorkspaceSize(int num_experts) const; - - [[nodiscard]] int getSM() const; - -private: - template - void dispatchToArch(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, - ScaleBiasType const* biases, bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, - HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, bool use_fused_moe, float const** alpha_scale_ptr_array, - cudaStream_t stream, int* occupancy = nullptr); - - template - void runGemm(T const* A, WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, - bool bias_is_broadcast, void* C, int64_t const* total_tokens_including_expert, - HopperGroupedGemmInput layout_info, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, - cutlass_extensions::CutlassGemmConfig 
chosen_conf); - -private: - int sm_{}; - int multi_processor_count_{}; - mutable int num_experts_ = 0; - mutable size_t gemm_workspace_size_ = 0; - size_t calcMaxWorkspaceSize(int num_experts) const; -}; - -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu deleted file mode 100644 index 3aa96502d39..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -#ifdef ENABLE_BF16 -template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16, __nv_bfloat16>; -#endif -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu deleted file mode 100644 index fbb5270455e..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -#ifdef ENABLE_BF16 -template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t, __nv_bfloat16>; -#endif -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu deleted file mode 100644 index 78f1a93a6a8..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
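The per-dtype `.cu` files deleted here each contain a single explicit instantiation of `MoeGemmRunner`, the standard trick for splitting long CUTLASS template compilations across translation units and gating them on `ENABLE_BF16` / `ENABLE_FP8`. A generic sketch of that pattern with placeholder names (`GemmRunner`, `gemm_runner_fp32_fp32.cpp` are illustrations, not files from this diff):

```cpp
// Sketch of the per-dtype explicit-instantiation pattern: the template
// definition lives in one header, and each translation unit instantiates a
// single type combination so heavy compilation is parallelised across files.

// gemm_runner.h
template <class T, class WeightType>
class GemmRunner {
public:
    void run(const T* a, const WeightType* b, T* c, int n) {
        for (int i = 0; i < n; ++i) c[i] = a[i] * static_cast<T>(b[i]);  // stand-in for the real GEMM
    }
};
extern template class GemmRunner<float, float>;  // suppress implicit instantiation in other TUs

// gemm_runner_fp32_fp32.cpp
template class GemmRunner<float, float>;         // the one TU that actually instantiates it
```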
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -#ifdef ENABLE_BF16 -template class MoeGemmRunner<__nv_bfloat16, uint8_t, __nv_bfloat16>; -#endif -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu deleted file mode 100644 index 69c4b6a15a8..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -template class MoeGemmRunner; -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu deleted file mode 100644 index 4ffa5485f0f..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -template class MoeGemmRunner; -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu deleted file mode 100644 index 424b817b876..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -template class MoeGemmRunner; -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu deleted file mode 100644 index f317023565c..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -template class MoeGemmRunner; -} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu deleted file mode 100644 index c6b8fe78724..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h" - -namespace tensorrt_llm -{ -#ifdef ENABLE_FP8 -template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, half>; -#ifdef ENABLE_BF16 -template class MoeGemmRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>; -#endif -// template class MoeGemmRunner<__nv_fp8_e5m2, __nv_fp8_e5m2>; -#endif -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h deleted file mode 100644 index 2a337e6ca4e..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template.h +++ /dev/null @@ -1,823 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Ignore CUTLASS warnings about type punning -#ifdef __GNUC__ // Check if the compiler is GCC or Clang -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif - -#include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/gemm/device/gemm_grouped.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" - -#include "cute/tensor.hpp" - -#include "cutlass/cutlass.h" - -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/tensor_ref.h" - -#include "cutlass_extensions/compute_occupancy.h" -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" -#include "cutlass_extensions/gemm/threadblock/default_mma.h" - -#ifdef __GNUC__ // Restore GCC-specific diagnostics -#pragma GCC diagnostic pop -#endif - -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/common/logger.h" - -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" - -#include "moe_gemm_kernels_template_sm90.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" -#include - -#include -#include -#include -#include - -namespace tensorrt_llm -{ -namespace kernels::cutlass_kernels -{ - -// ============================= Variable batched Gemm things =========================== -template -void 
genericMoeGemmKernelLauncher(T const* A, WeightType const* B, GemmOutputType const* weight_scales, - GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, - int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int const multi_processor_count, bool use_fused_moe, - float const** alpha_scale_ptr_array, cudaStream_t stream, int* kernel_occupancy = nullptr) -{ -#if defined(ENABLE_FP8) - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value - || cutlass::platform::is_same::value - || cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for fp8, bfloat16, half, float"); -#elif defined(ENABLE_BF16) - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value - || cutlass::platform::is_same::value, - "Specialized for bfloat16, half, float"); -#else - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for half, float"); -#endif - - static_assert(cutlass::platform::is_same::value - || cutlass::platform::is_same::value - || cutlass::platform::is_same::value, - ""); - - static_assert(!cutlass::platform::is_same::value, - "Sm90 architecture should use specialised kernels"); - - // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. - using ElementType = typename TllmToCutlassTypeAdapter::type; - using CutlassGemmOutputType = typename TllmToCutlassTypeAdapter::type; - using CutlassWeightType = typename TllmToCutlassTypeAdapter::type; - if (!use_fused_moe) - { - // We need separate config for each architecture since we will target different tensorcore instructions. For - // float, we do not target TCs. - using MixedGemmArchTraits = cutlass::gemm::kernel::MixedGemmArchTraits; - using ElementAccumulator = typename MixedGemmArchTraits::AccType; - - using EpilogueOp = typename tensorrt_llm::cutlass_extensions::Epilogue::Op; - - typename EpilogueOp::Params epilogue_op( - ElementAccumulator(1.f), biases ? ElementAccumulator(1.f) : ElementAccumulator(0.f)); - -#if defined(ENABLE_FP8) - if constexpr ((std::is_same_v - || std::is_same_v) &&std::is_same_v) - { - TLLM_CHECK_WITH_INFO(weight_scales == nullptr && biases == nullptr && alpha_scale_ptr_array, - "weight_scales and biases should be nullptr and alpha_scale_ptr_array shouldn't be nullptr for FP8 " - "Ada"); - epilogue_op.alpha_ptr_array = alpha_scale_ptr_array; - } -#endif - - // Finally, set up the kernel. 
- using GemmKernel_ = typename cutlass::gemm::kernel::DefaultGemmGrouped::GemmKernel; - - using GemmKernel = cutlass::gemm::kernel::MoeFCGemm; - - using GemmGrouped = cutlass::gemm::device::GemmGrouped; - - if (kernel_occupancy != nullptr) - { - *kernel_occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel(); - return; - } - int occupancy = std::min(2, GemmGrouped::maximum_active_blocks()); - TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run GroupedGEMM kernel"); - int const threadblock_count = multi_processor_count * occupancy; - - int const group_size = gemm_k; - typename GemmGrouped::Arguments args(num_experts, threadblock_count, group_size, epilogue_op, - reinterpret_cast(A), reinterpret_cast(B), - reinterpret_cast(weight_scales), - reinterpret_cast(biases), bias_is_broadcast, - reinterpret_cast(C), total_tokens_including_expert, gemm_n, gemm_k); - - GemmGrouped gemm; - - auto can_implement = gemm.can_implement(args); - TLLM_CHECK_WITH_INFO(can_implement == cutlass::Status::kSuccess, - "MoE FC kernel will fail for params. Error: " + std::string(cutlassGetStatusString(can_implement))); - - auto init_status = gemm.initialize(args); - TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, - "Failed to initialize cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(init_status))); - - auto run_status = gemm.run(stream); - TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, - "Failed to run cutlass grouped gemm. Error: " + std::string(cutlassGetStatusString(run_status))); - } - else if constexpr (sizeof(ElementType) == 2 && sizeof(CutlassWeightType) == 2 - && (std::is_same_v - || std::is_same_v) ) // use fused moe gemm - // kernel.. (only support - // fp16 or bf16) - { - sm80_generic_fused_moe_gemm_kernelLauncher(reinterpret_cast(A), - reinterpret_cast(B), reinterpret_cast(biases), - bias_is_broadcast, reinterpret_cast(C), total_tokens_including_expert, num_rows, gemm_n, - gemm_k, num_experts, multi_processor_count, stream, kernel_occupancy); - } -} - -} // namespace kernels::cutlass_kernels - -template -static void dispatch(T const* A, WeightType const* B, GemmOutputType const* weight_scales, GemmOutputType const* biases, - bool bias_is_broadcast, GemmOutputType* C, int64_t const* total_tokens_including_expert, int64_t num_rows, - int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, - int multi_processor_count, bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, - int* occupancy = nullptr) -{ - - static_assert(!std::is_same_v, "Use TMA specialised functions for arch SM90"); -#if defined(ENABLE_FP8) - constexpr bool isFp8 = std::is_same_v || std::is_same_v; -#else - constexpr bool isFp8 = false; -#endif - - if constexpr ((Stages == 2 || Arch::kMinComputeCapability >= 80) - && (!isFp8 || std::is_same_v) ) - { - kernels::cutlass_kernels::genericMoeGemmKernelLauncher(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - } - else - { - TLLM_THROW( - "Cutlass gemm. 
Not instantiated for arch %d with stages set to %d", Arch::kMinComputeCapability, Stages); - } -} - -template -void dispatchGemmConfig(T const* A, WeightType const* B, GemmOutputType const* weight_scales, - GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, - int64_t const* total_tokens_including_expert, int64_t num_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) -{ - switch (gemm_config.stages) - { - case 2: - dispatch(A, B, weight_scales, - biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case 3: - dispatch(A, B, weight_scales, - biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case 4: - dispatch(A, B, weight_scales, - biases, bias_is_broadcast, C, total_tokens_including_expert, num_rows, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - default: TLLM_THROW("dispatchGemmConfig does not support stages %d", gemm_config.stages); break; - } -} - -// This overload will handle tensorop gemms. It is disabled via SFINAE for fp32. -// This overload is only enabled when T == WeightType. -template ::value -#if defined(ENABLE_FP8) - && !std::is_same::value && !std::is_same::value -#endif - && std::is_same::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, - GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, - int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) -{ - switch (gemm_config.tile_config) - { - case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: - TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); - if constexpr (arch::kMinComputeCapability >= 75) - { - dispatchGemmConfig, - cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - } - break; - case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: - TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); - if constexpr (arch::kMinComputeCapability >= 75) - { - dispatchGemmConfig, - cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - } - break; - case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, 
total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; - case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: - TLLM_THROW("GEMM config should have already been set by heuristic."); - break; - default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break; - } -} - -// Tensorop GEMM overload -// Overload for quantize MoE GEMMs. We disable some warp configs here since they will not be used and we can improve -// compile time -template ::value && !std::is_same::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, - GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, - int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) -{ - switch (gemm_config.tile_config) - { - case cutlass_extensions::CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: - TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); - if constexpr (arch::kMinComputeCapability >= 75) - { - dispatchGemmConfig, - cutlass::gemm::GemmShape<16, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - } - break; - case cutlass_extensions::CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: - TLLM_CHECK_WITH_INFO(arch::kMinComputeCapability >= 75, "Invalid config on Volta"); - if constexpr (arch::kMinComputeCapability >= 75) - { - dispatchGemmConfig, - cutlass::gemm::GemmShape<16, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, - multi_processor_count, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - } - break; - case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 32, 64>>(A, B, 
weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; - case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: - TLLM_THROW("GEMM config should have already been set by heuristic."); - break; - default: TLLM_THROW("Config is invalid for mixed type tensorop GEMM."); break; - } -} - -// This overload will handle tensorop gemms. -// This overload is only enabled when T == WeightType and T == __nv_fp8_e4m3 or __nv_fp8_e5m2 -#if defined(ENABLE_FP8) -template ::value || std::is_same::value) - && std::is_same::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, - GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, - int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) -{ - switch (gemm_config.tile_config) - { - case cutlass_extensions::CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: - dispatchGemmConfig, - cutlass::gemm::GemmShape<16, 64, 128>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<32, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, 
alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 64, 64>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; - case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: - TLLM_THROW("GEMM config should have already been set by heuristic."); - break; - default: TLLM_THROW("Config is invalid for same type tensorop GEMM."); break; - } -} -#endif - -// This overload will handle simt gemms. It is disabled via SFINAE for tensorop. -template ::value>::type* = nullptr> -void dispatchMoeGemmToCutlass(T const* A, WeightType const* B, GemmOutputType const* weight_scales, - GemmOutputType const* biases, bool bias_is_broadcast, GemmOutputType* C, - int64_t const* total_tokens_including_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, bool use_fused_moe, - float const** alpha_scale_ptr_array, cudaStream_t stream, int* occupancy = nullptr) -{ - switch (gemm_config.tile_config) - { - case cutlass_extensions::CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: - dispatchGemmConfig, - cutlass::gemm::GemmShape<64, 64, 8>>(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, multi_processor_count, - use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - break; - case cutlass_extensions::CutlassTileConfig::Undefined: TLLM_THROW("GEMM config undefined."); break; - case cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic: - TLLM_THROW("GEMM config should have already been set by heuristic."); - break; - default: TLLM_THROW("Unsupported config for float MoE gemm."); break; - } -} - -template -std::vector -MoeGemmRunner::getConfigs() const -{ - return getConfigs(sm_); -} - -template -std::vector MoeGemmRunner::getConfigs( - int sm) -{ - std::vector candidate_configs = getHopperConfigs(sm); - std::vector ampere_configs = getAmpereConfigs(sm); - std::copy(ampere_configs.begin(), ampere_configs.end(), std::back_inserter(candidate_configs)); - - return candidate_configs; -} - -template -std::vector -MoeGemmRunner::getAmpereConfigs(int sm) -{ - using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; - static constexpr auto weight_only_flag - = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; - static constexpr auto simt_only_flag - = std::is_same::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE; - static constexpr auto fp8_only_flag = use_fp8 ? 
CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE; - int const max_split_k = 1; - int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM; - int const enable_hopper = CutlassGemmConfig::NONE; - - auto config_type_param = static_cast( - weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); - - if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation()) - { - return {}; - } - - std::vector ampere_configs - = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); - return ampere_configs; -} - -template -std::vector -MoeGemmRunner::getHopperConfigs(int sm) -{ - using tensorrt_llm::cutlass_extensions::CutlassGemmConfig; - static constexpr auto weight_only_flag - = std::is_same::value ? CutlassGemmConfig::NONE : CutlassGemmConfig::WEIGHT_ONLY; - static constexpr auto simt_only_flag - = std::is_same::value ? CutlassGemmConfig::SIMT_ONLY : CutlassGemmConfig::NONE; - int const max_split_k = 1; - int const grouped_gemm_flag = CutlassGemmConfig::GROUPED_GEMM; - int const enable_hopper = CutlassGemmConfig::HOPPER; - static constexpr auto fp8_only_flag = use_fp8 ? CutlassGemmConfig::FP8_ONLY : CutlassGemmConfig::NONE; - auto config_type_param = static_cast( - weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); - - if (!kernels::cutlass_kernels::isValidHopperMOESpecialisation()) - { - return {}; - } - - std::vector hopper_configs - = kernels::cutlass_kernels::get_candidate_configs(sm, max_split_k, config_type_param); - return hopper_configs; -} - -template -bool MoeGemmRunner::isHopperSpecialised( - cutlass_extensions::CutlassGemmConfig gemm_config) const -{ - bool config_is_sm90 = gemm_config.is_sm90; - return supportsHopperSpecialisation() && config_is_sm90; -} - -template -bool MoeGemmRunner::supportsHopperSpecialisation() const -{ - return sm_ == 90 && kernels::cutlass_kernels::isValidHopperMOESpecialisation(); -} - -template -int MoeGemmRunner::getSM() const -{ - return this->sm_; -} - -// currently support sm80 bf16/fp16 gate activation, only set predication tensor for m direction -template -bool MoeGemmRunner::supportsFusedGatedActivation( - bool is_gated_activation, int gemm_n, int gemm_k) const -{ - constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true; - return is_gated_activation && std::is_same_v && !std::is_same_v && !use_fp8 - && (this->getSM() >= 80) && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION; -} - -template -bool MoeGemmRunner::isFusedGatedActivation( - cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const -{ - return supportsFusedGatedActivation(is_gated_activation, gemm_n, gemm_k) && !gemm_config.is_sm90; -} - -template -MoeGemmRunner::MoeGemmRunner() -{ - int device{-1}; - tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device)); - sm_ = tensorrt_llm::common::getSMVersion(); - tensorrt_llm::common::check_cuda_error( - cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device)); -} - -template -template -void MoeGemmRunner::dispatchToArch(T const* A, - WeightType const* B, ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, - void* C_void, int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, - int64_t gemm_n, int64_t gemm_k, int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, - bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, 
int* occupancy) -{ - static_assert(std::is_same_v, - "Separate Scale/Bias type is not supported. This is assumed to be the gemm output type"); - - // For now we always cast this to output type. - // In the future this will vary based on what fusions are applied for FP8 - auto* C = reinterpret_cast(C_void); - - TLLM_CHECK_WITH_INFO( - sm_ >= 89 || !hopper_input.isValid(), "Hopper input information is set for non specialised implementation"); - TLLM_CHECK_WITH_INFO( - sm_ == 90 || !gemm_config.is_sm90, "Hopper configuration provided for non-Hopper architecture"); - - if (sm_ >= 75 && sm_ < 80) - { - dispatchMoeGemmToCutlass(A, B, weight_scales, - biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, num_experts, - gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, occupancy); - } - else if (sm_ >= 80 && sm_ < 90) - { - if constexpr (use_fp8) - { -#if defined(ENABLE_FP8) - static_assert(!std::is_same_v && !std::is_same_v, - "FP8 GEMM Output not supported"); -#endif - - TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); - dispatchMoeGemmToCutlass(A, B, - weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, - num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, - occupancy); - } - else - { - dispatchMoeGemmToCutlass(A, B, - weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, - num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, - occupancy); - } - } - else if (sm_ >= 90) - { - if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation()) - { - - // We allow both SM90 and SM80 configurations to coexist because for some cases with small numbers of tokens - // SM80 is faster. 
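The dispatchToArch body above (and continuing below) routes to a kernel family keyed on the detected SM version: Turing for 75-79, Ampere for 80-89 (with FP8 restricted to SM89/Ada), and the Hopper specialisation for SM90+ with an Ampere fallback when the chosen config is not an SM90 one. A simplified sketch of that routing, with a hypothetical launchForArch<>() standing in for the dispatchMoeGemmToCutlass calls:

```cpp
#include <stdexcept>

// Hypothetical stand-in for the per-architecture dispatchMoeGemmToCutlass calls.
template <int MinComputeCapability>
void launchForArch()
{
    // The real code instantiates the CUTLASS kernels for this architecture here.
}

// Simplified sketch of the SM-version routing performed by dispatchToArch.
inline void dispatchToArchSketch(int sm, bool is_fp8, bool config_is_sm90)
{
    if (sm >= 75 && sm < 80)
    {
        launchForArch<75>(); // Turing tensor-op kernels
    }
    else if (sm >= 80 && sm < 90)
    {
        if (is_fp8 && sm != 89)
            throw std::runtime_error("below SM90, FP8 is only supported on SM89 (Ada)");
        if (is_fp8)
            launchForArch<89>();
        else
            launchForArch<80>(); // Ampere tensor-op kernels
    }
    else if (sm >= 90)
    {
        // SM90 and SM80 configs coexist because small token counts can be faster
        // on the Ampere kernels; route based on the selected config.
        if (config_is_sm90)
            launchForArch<90>();
        else
            launchForArch<80>();
    }
    else
    {
        throw std::runtime_error("Arch unsupported for MoE GEMM");
    }
}
```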
We check here to see which is selected - if (gemm_config.is_sm90) - { - TLLM_CHECK_WITH_INFO(biases != nullptr || hopper_input.ptr_c == nullptr, - "Input biases and hopper input disagree if bias is enabled"); - TLLM_CHECK_WITH_INFO(hopper_input.isValid(), "Calling SM90 configuration with invalid hopper config"); - - // Select the appropriate fusion function - auto select_function = [&]() - { - switch (hopper_input.fusion) - { - case HopperGroupedGemmInput::EpilogueFusion::FINALIZE: - return &dispatchMoeGemmSelectTileShapeSM90; - case HopperGroupedGemmInput::EpilogueFusion::NONE: - return &dispatchMoeGemmSelectTileShapeSM90; - case HopperGroupedGemmInput::EpilogueFusion::ACTIVATION: - case HopperGroupedGemmInput::EpilogueFusion::GATED_ACTIVATION: - default: TLLM_THROW("Unimplemented fusion %d requested", (int) hopper_input.fusion); - }; - }; - auto selected_func = select_function(); - selected_func( - hopper_input, num_experts, gemm_config, multi_processor_count_, stream, occupancy, nullptr); - return; - } - - // Fallthrough to SM80 impl below - } - - // Do Ampere case instead - if constexpr (kernels::cutlass_kernels::isValidAmpereMOESpecialisation()) - { - TLLM_CHECK_WITH_INFO(!hopper_input.isValid(), - "Non-specialised Hopper implementation is being rerouted to fallback implementation so input " - "information is not required"); - TLLM_CHECK_WITH_INFO(!gemm_config.is_sm90, - "GEMM config is for SM90 configuration, but this configuration is not valid for Hppper"); - dispatchMoeGemmToCutlass(A, B, - weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, total_rows, gemm_n, gemm_k, - num_experts, gemm_config, multi_processor_count_, use_fused_moe, alpha_scale_ptr_array, stream, - occupancy); - } - else - { - TLLM_THROW("Configuration expects SM80 but configuration is not supported by SM80 kernels"); - } - } - else - { - TLLM_THROW("Arch unsupported for MoE GEMM"); - } -} - -template -size_t MoeGemmRunner::getMaxWorkspaceSize(int num_experts) const -{ - if (num_experts != num_experts_) - { - TLLM_LOG_TRACE("Calling getMaxWorkspaceSize() with a new expert count %d vs %d", num_experts, num_experts_); - num_experts_ = num_experts; - gemm_workspace_size_ = calcMaxWorkspaceSize(num_experts); - } - return gemm_workspace_size_; -} - -template -size_t MoeGemmRunner::calcMaxWorkspaceSize(int num_experts) const -{ - if (!supportsHopperSpecialisation()) - { - return 0; - } - if constexpr (kernels::cutlass_kernels::isValidHopperMOESpecialisation()) - { - auto configs = getHopperConfigs(sm_); - size_t max_size = 0; - bool has_config = false; - for (auto conf : configs) - { -#define CALC_SIZE_FUSION(FUSION) \ - do \ - { \ - try \ - { \ - size_t size = calcMaxWorkspaceSizeSM90( \ - num_experts, conf, multi_processor_count_); \ - max_size = std::max(max_size, size); \ - has_config = true; \ - } \ - catch (tensorrt_llm::common::TllmException const& e) \ - { \ - TLLM_LOG_TRACE("Unsupported config skipped when calculating MOE workspace size"); \ - } \ - } while (0) - - CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::NONE); - CALC_SIZE_FUSION(HopperGroupedGemmInput::EpilogueFusion::FINALIZE); - } - TLLM_CHECK_WITH_INFO(has_config, "Could not find valid config when calculating workspace size"); - return max_size; - } - else - { - TLLM_THROW("Attempting to calculate Hopper GEMM workspace size with unsupported weight combination"); - return 0; - } -} - -template -template -void MoeGemmRunner::runGemm(T const* A, WeightType const* B, - ScaleBiasType const* weight_scales, ScaleBiasType 
const* biases, bool bias_is_broadcast, void* C, - int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, - int64_t gemm_n, int64_t gemm_k, int num_experts, bool use_fused_moe, float const** alpha_scale_ptr_array, - cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf) -{ - dispatchToArch(A, B, weight_scales, biases, bias_is_broadcast, C, total_tokens_including_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, chosen_conf, use_fused_moe, alpha_scale_ptr_array, - stream, nullptr); -} - -template -void MoeGemmRunner::moeGemmBiasAct(T const* A, WeightType const* B, - ScaleBiasType const* weight_scales, ScaleBiasType const* biases, bool bias_is_broadcast, void* C, - int64_t const* total_tokens_including_expert, HopperGroupedGemmInput hopper_input, int64_t total_rows, - int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, bool use_fused_moe, - float const** alpha_scale_ptr_array, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig chosen_conf) -{ - switch (activation_type) - { - case ActivationType::Relu: - runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, - alpha_scale_ptr_array, stream, chosen_conf); - break; - case ActivationType::Gelu: - runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, - alpha_scale_ptr_array, stream, chosen_conf); - break; - case ActivationType::Silu: - runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, - alpha_scale_ptr_array, stream, chosen_conf); - break; - case ActivationType::Identity: - runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, - alpha_scale_ptr_array, stream, chosen_conf); - break; - case ActivationType::Swiglu: - runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, - alpha_scale_ptr_array, stream, chosen_conf); - break; - case ActivationType::Geglu: - runGemm(A, B, weight_scales, biases, bias_is_broadcast, C, - total_tokens_including_expert, hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, - alpha_scale_ptr_array, stream, chosen_conf); - break; - case ActivationType::InvalidType: TLLM_THROW("Activation type for fpA_intB must be valid."); break; - default: TLLM_THROW("Invalid activation type."); break; - } -} - -template -void MoeGemmRunner::moeGemm(T const* A, WeightType const* B, - ScaleBiasType const* weight_scales, void* C, int64_t const* total_tokens_including_expert, - HopperGroupedGemmInput hopper_input, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, - bool use_fused_moe, float const** alpha_scale_ptr_array, cudaStream_t stream, - cutlass_extensions::CutlassGemmConfig chosen_conf) -{ - runGemm(A, B, weight_scales, nullptr, true, C, total_tokens_including_expert, - hopper_input, total_rows, gemm_n, gemm_k, num_experts, use_fused_moe, alpha_scale_ptr_array, stream, - chosen_conf); -} - -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h 
b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h deleted file mode 100644 index 3efb42f41ef..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_template_sm90.h +++ /dev/null @@ -1,222 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Ignore CUTLASS warnings about type punning -#ifdef __GNUC__ // Check if the compiler is GCC or Clang -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif // __GNUC__ - -#include "cutlass/array.h" -#include "cutlass/numeric_conversion.h" - -#include "cutlass/gemm/device/gemm_grouped.h" -#include "cutlass/gemm/kernel/default_gemm_grouped.h" - -#include "cutlass/cutlass.h" - -#include "cute/tensor.hpp" - -#include "cutlass/epilogue/collective/collective_builder.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/epilogue/thread/linear_combination.h" -#include "cutlass/gemm/collective/collective_builder.hpp" -#include "cutlass/gemm/device/gemm_universal_adapter.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/gemm/group_array_problem_shape.hpp" -#include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/tensor_ref.h" - -#include "cutlass_extensions/compute_occupancy.h" -#include "cutlass_extensions/epilogue_helpers.h" -#include "cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h" -#include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h" -#include "cutlass_extensions/gemm/threadblock/default_mma.h" - -#ifdef __GNUC__ // Check if the compiler is GCC or Clang -#pragma GCC diagnostic pop -#endif // __GNUC__ - -#include "tensorrt_llm/common/assert.h" -#include "tensorrt_llm/common/cudaUtils.h" -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" - -#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h" - -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_launcher_sm90.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels.h" -#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" - -#include -#include -#include -#include - -namespace tensorrt_llm -{ -using EpilogueFusion = HopperGroupedGemmInput::EpilogueFusion; - -template -void dispatchMoeGemmSelectBiasSM90(HopperGroupedGemmInput hopper_input, int num_experts, int multi_processor_count, - cudaStream_t stream, int* occupancy, size_t* workspace_size) -{ - static_assert(kernels::cutlass_kernels::isValidHopperMOESpecialisation(), - "Invalid hopper configuration invoked, fallback to Sm80"); - - TLLM_CHECK_WITH_INFO( - workspace_size || hopper_input.isValid(), "Hopper specialisation is missing additional input information"); - - // auto func = hopper_input.ptr_c ? 
- // kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper - // : - // kernels::cutlass_kernels::genericMoeGemmKernelLauncherHopper; - // TODO(dastokes) Re-enable bias when CUTLASS supports it - auto func = kernels::cutlass_kernels::sm90_generic_moe_gemm_kernelLauncher; - func(hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); -} - -/* - 1x1x1 cluster shape is are supported for any tile shape. - - 2x1x1 cluster shape is only supported for when the M tile is at least 128. - - 1x2x1 cluster shape is only supported when the N tile is at least 128. - - 2x2x1 cluster shape is only supported when both the M and N tiles are at least 128. - - We make the above restrictions are to improve compilation speed in TRT-LLM by pruning kernels - that may not be very useful in practice. - */ -template -constexpr bool are_tile_shapes_supported() -{ - using namespace cute; - [[maybe_unused]] constexpr int cta_m = get<0>(CTAShape{}); - [[maybe_unused]] constexpr int cta_n = get<1>(CTAShape{}); - constexpr int cga_m = get<0>(ClusterShape{}); - constexpr int cga_n = get<1>(ClusterShape{}); - - if constexpr (cga_m == _1{} && cga_n == _1{}) - { - return true; - } - else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{}) - { - return true; - } - else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{}) - { - return true; - } - else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{}) - { - return true; - } - else - { - return false; - } -} - -template -void dispatchMoeGemmSelectClusterShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy, - size_t* workspace_size) -{ - using namespace cute; - switch (gemm_config.cluster_shape) - { -#define SHAPE_CASE(M, N, K) \ - case cutlass_extensions::ClusterShape::ClusterShape_##M##x##N##x##K: \ - { \ - using ClusterShape = Shape<_##M, _##N, _##K>; \ - if constexpr (are_tile_shapes_supported()) \ - { \ - dispatchMoeGemmSelectBiasSM90( \ - hopper_input, num_experts, multi_processor_count, stream, occupancy, workspace_size); \ - break; \ - } \ - else \ - { \ - TLLM_THROW("Unsupported tile and cluster shape combination"); \ - } \ - } - - SHAPE_CASE(1, 1, 1) - SHAPE_CASE(1, 2, 1) - - SHAPE_CASE(2, 1, 1) - SHAPE_CASE(2, 2, 1) - -#undef SHAPE_CASE - default: TLLM_THROW("Unsupported config for MoE gemm."); - } -} // namespace tensorrt_llm - -template -void dispatchMoeGemmSelectTileShapeSM90(HopperGroupedGemmInput hopper_input, int num_experts, - cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, int* occupancy, - size_t* workspace_size) -{ - using namespace cute; - - switch (gemm_config.tile_config_sm90) - { -#define SHAPE_CASE(M, N, K) \ - case cutlass_extensions::CutlassTileConfigSM90::CtaShape##M##x##N##x##K##B: \ - { \ - constexpr int KtileBytes = K / sizeof(T); \ - using KTileDim = Int; \ - using TileShape = Shape<_##M, _##N, KTileDim>; \ - dispatchMoeGemmSelectClusterShapeSM90( \ - hopper_input, num_experts, gemm_config, multi_processor_count, stream, occupancy, workspace_size); \ - break; \ - } - - SHAPE_CASE(128, 16, 128) - SHAPE_CASE(128, 32, 128) - SHAPE_CASE(128, 64, 128) - SHAPE_CASE(128, 128, 128) - SHAPE_CASE(128, 256, 128) - SHAPE_CASE(256, 128, 128) - -#undef SHAPE_CASE - case cutlass_extensions::CutlassTileConfigSM90::Undefined: TLLM_THROW("GEMM config undefined."); break; - case 
cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic: - TLLM_THROW("GEMM config should have already been set by heuristic."); - break; - default: TLLM_THROW("Unsupported config for MoE gemm."); break; - } -} - -template -size_t calcMaxWorkspaceSizeSM90( - int num_experts, cutlass_extensions::CutlassGemmConfig gemm_config, int multi_processor_count) -{ - size_t count; - // Most of the values are ignored for WS size calculation. We reuse the function to reduce the template bloat - dispatchMoeGemmSelectTileShapeSM90( - HopperGroupedGemmInput{}, num_experts, gemm_config, multi_processor_count, cudaStream_t{0}, nullptr, &count); - return count; -} - -} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h deleted file mode 100644 index 959d0ea088c..00000000000 --- a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h +++ /dev/null @@ -1,44 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "cutlass/arch/mma_sm90.h" -#include "cutlass_extensions/epilogue_helpers.h" - -namespace tensorrt_llm::kernels::cutlass_kernels -{ - -// Hopper arch -template -constexpr bool isValidHopperMOESpecialisation() -{ -#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) - return cutlass::platform::is_same::value - && cutlass::platform::is_same::value; -#else - return false; // CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED is set when Hopper kernels are enabled -#endif -} - -// Hopper arch -template -constexpr bool isValidAmpereMOESpecialisation() -{ - return true; // Default to true -} - -} // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 50299140312..90c3cbc1d3c 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -39,8 +39,6 @@ def _get_version(): cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) flashinfer = root / "3rdparty" / "flashinfer" turbomind = root / "3rdparty" / "turbomind" -tensorrt_llm_parent = root / "3rdparty" -tensorrt_llm = root / "3rdparty" / "tensorrt_llm" include_dirs = [ cutlass.resolve() / "include", cutlass.resolve() / "tools" / "util" / "include", @@ -53,8 +51,6 @@ def _get_version(): "cublasLt", turbomind.resolve(), turbomind.resolve() / "src", - tensorrt_llm_parent.resolve(), - tensorrt_llm.resolve() / "cutlass_extensions" / "include", ] nvcc_flags = [ diff --git a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py index 34565c9ff65..d9d77a9ae24 100644 --- a/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py +++ b/test/srt/sampling/penaltylib/test_srt_endpoint_with_penalizers.py @@ -36,7 +36,7 @@ 
def tearDownClass(cls):
     def run_decode(
         self,
         return_logprob=True,
-        top_logprobs_num=3,
+        top_logprobs_num=5,
         return_text=True,
         n=1,
         **sampling_params,
@@ -58,8 +58,7 @@ def run_decode(
                 "logprob_start_len": 0,
             },
         )
-        print(json.dumps(response.json()))
-        print("=" * 100)
+        assert response.status_code == 200, "Request failed: " + response.text
 
     def test_default_values(self):
         self.run_decode()
@@ -112,4 +111,4 @@ def test_repetition_penalty(self):
 
 
 if __name__ == "__main__":
-    unittest.main()
+    unittest.main(verbosity=3)
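Stepping back from the patch itself, the deleted non-Hopper launcher follows the standard CUTLASS grouped-GEMM launch sequence: query occupancy, derive a threadblock count from the SM count, then can_implement, initialize, and run, turning any failed status into an error. A hedged sketch of that control flow; Status, Args and FakeGroupedGemm below are placeholders for cutlass::Status, GemmGrouped::Arguments and GemmGrouped, not the real MoeFCGemm instantiation:

```cpp
#include <algorithm>
#include <stdexcept>

// Self-contained sketch of the deleted launch sequence with placeholder types.
enum class Status
{
    kSuccess,
    kErrorInternal
};

struct Args
{
    int threadblock_count = 0; // plus pointers, shapes and epilogue params in the real code
};

struct FakeGroupedGemm
{
    static int maximum_active_blocks() { return 2; }
    Status can_implement(Args const&) { return Status::kSuccess; }
    Status initialize(Args const&) { return Status::kSuccess; }
    Status run() { return Status::kSuccess; }
};

inline void launchGroupedGemmSketch(int multi_processor_count)
{
    // 1. Size the persistent grid: the deleted code caps occupancy at 2 CTAs per SM.
    int occupancy = std::min(2, FakeGroupedGemm::maximum_active_blocks());
    if (occupancy <= 0)
        throw std::runtime_error("GPU lacks the shared memory to run the grouped GEMM");

    Args args;
    args.threadblock_count = multi_processor_count * occupancy;

    // 2. Validate, initialize, then launch, turning each failed status into an error.
    FakeGroupedGemm gemm;
    if (gemm.can_implement(args) != Status::kSuccess)
        throw std::runtime_error("MoE FC kernel will fail for these params");
    if (gemm.initialize(args) != Status::kSuccess)
        throw std::runtime_error("failed to initialize the grouped GEMM");
    if (gemm.run() != Status::kSuccess)
        throw std::runtime_error("failed to run the grouped GEMM");
}
```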
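The deleted are_tile_shapes_supported() helper prunes SM90 cluster shapes that only pay off with large CTA tiles: 1x1 clusters are always allowed, 2x1 needs an M tile of at least 128, 1x2 needs an N tile of at least 128, and 2x2 needs both. The same rule, sketched with plain integer template parameters instead of cute::Shape types:

```cpp
// Same pruning rule as the deleted are_tile_shapes_supported(), expressed with
// plain integers instead of cute::Shape types (illustrative sketch only).
template <int CtaM, int CtaN, int ClusterM, int ClusterN>
constexpr bool tileClusterSupported()
{
    if (ClusterM == 1 && ClusterN == 1)
        return true; // 1x1 clusters work with any tile
    if (ClusterM == 2 && ClusterN == 1)
        return CtaM >= 128; // needs a tall M tile
    if (ClusterM == 1 && ClusterN == 2)
        return CtaN >= 128; // needs a wide N tile
    if (ClusterM == 2 && ClusterN == 2)
        return CtaM >= 128 && CtaN >= 128; // needs both
    return false; // anything else is pruned to keep compile times down
}

static_assert(tileClusterSupported<128, 16, 1, 1>(), "1x1 clusters allow any tile");
static_assert(!tileClusterSupported<128, 16, 1, 2>(), "1x2 clusters need N >= 128");
static_assert(tileClusterSupported<256, 128, 2, 2>(), "2x2 clusters need M and N >= 128");
```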
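calcMaxWorkspaceSize in the deleted template reuses the normal SM90 dispatch path as a dry run: passing a workspace_size out-pointer makes each candidate config report its requirement without launching anything, unsupported tile/cluster combinations throw and are skipped, and the maximum over the surviving configs is returned. A sketch of that pattern with a hypothetical config type and dispatcher:

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for CutlassGemmConfig and dispatchMoeGemmSelectTileShapeSM90.
// When workspace_size is non-null, the dispatcher only reports the bytes it would
// need for that config and returns without launching anything.
struct GemmConfig
{
    int tile_id = 0;
};

inline void dispatchSketch(GemmConfig const& cfg, int num_experts, std::size_t* workspace_size)
{
    if (cfg.tile_id < 0)
        throw std::runtime_error("unsupported tile/cluster combination");
    if (workspace_size)
    {
        *workspace_size = std::size_t(num_experts) * 1024; // dry run: size query only
        return;
    }
    // ... the real path would launch the grouped GEMM here ...
}

inline std::size_t calcMaxWorkspaceSketch(std::vector<GemmConfig> const& configs, int num_experts)
{
    std::size_t max_size = 0;
    bool has_config = false;
    for (auto const& cfg : configs)
    {
        try
        {
            std::size_t size = 0;
            dispatchSketch(cfg, num_experts, &size);
            max_size = std::max(max_size, size);
            has_config = true;
        }
        catch (std::exception const&)
        {
            // Config not supported by the compiled kernels: skip it, as the deleted code does.
        }
    }
    if (!has_config)
        throw std::runtime_error("no valid config when sizing the MoE workspace");
    return max_size;
}
```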
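A pattern repeated throughout the deleted dispatchGemmConfig, dispatchMoeGemmToCutlass, and moeGemmBiasAct functions is turning a runtime choice (stage count, tile config, activation type) into a compile-time template parameter via a switch with one instantiation per supported value. A small sketch of that idiom with a hypothetical activation enum:

```cpp
#include <stdexcept>

// Hypothetical activation enum; the deleted moeGemmBiasAct switches over
// ActivationType in the same way to select a template instantiation.
enum class Activation
{
    Relu,
    Gelu,
    Silu,
    Identity
};

template <Activation A>
void runGemmSketch()
{
    // Each instantiation would bake the epilogue for activation A into the kernel.
}

inline void dispatchActivation(Activation act)
{
    switch (act)
    {
    case Activation::Relu: runGemmSketch<Activation::Relu>(); break;
    case Activation::Gelu: runGemmSketch<Activation::Gelu>(); break;
    case Activation::Silu: runGemmSketch<Activation::Silu>(); break;
    case Activation::Identity: runGemmSketch<Activation::Identity>(); break;
    default: throw std::runtime_error("invalid activation type");
    }
}
```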