diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt b/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt new file mode 100644 index 00000000000..e479b298db4 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/CMakeLists.txt @@ -0,0 +1,22 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & +# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# +file(GLOB SRCS *.cpp) +file(GLOB CU_SRCS *.cu) + +add_library(common_src OBJECT ${SRCS} ${CU_SRCS}) +set_property(TARGET common_src PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET common_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp new file mode 100755 index 00000000000..eaaf6624472 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/assert.cpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/assert.h" + +namespace +{ + +bool initCheckDebug() +{ + auto constexpr kDebugEnabled = "TLLM_DEBUG_MODE"; + auto const debugEnabled = std::getenv(kDebugEnabled); + return debugEnabled && debugEnabled[0] == '1'; +} +} // namespace + +bool DebugConfig::isCheckDebugEnabled() +{ + static bool const debugEnabled = initCheckDebug(); + return debugEnabled; +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp new file mode 100644 index 00000000000..351257f4d2e --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp @@ -0,0 +1,360 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "tensorrt_llm/common/cublasMMWrapper.h"
+#include "tensorrt_llm/common/assert.h"
+#include "tensorrt_llm/common/cublasVersionCheck.h"
+#include <algorithm>
+
+#ifndef CUDART_VERSION
+#error CUDART_VERSION Undefined!
+#endif
+
+namespace tensorrt_llm
+{
+namespace common
+{
+
+CublasMMWrapper::CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle,
+    std::shared_ptr<cublasLtHandle_t> cublasltHandle, cudaStream_t stream, void* workspace)
+    : mCublasHandle(cublasHandle)
+    , mCublasLtHandle(cublasltHandle)
+    , mStream(stream)
+    , mCublasWorkspace(workspace)
+{
+}
+
+CublasMMWrapper::~CublasMMWrapper() {}
+
+CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper)
+    : mCublasHandle(wrapper.mCublasHandle)
+    , mCublasLtHandle(wrapper.mCublasLtHandle)
+    , mStream(wrapper.mStream)
+{
+}
+
+void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
+    int const k, int const lda, int const ldb, int const ldc, int8_t fastAcc)
+{
+    // --------------------------------------
+    // Create descriptors for the original matrices
+    check_cuda_error(
+        cublasLtMatrixLayoutCreate(&mADesc, mAType, transa == CUBLAS_OP_N ? m : k, transa == CUBLAS_OP_N ? k : m, lda));
+    check_cuda_error(
+        cublasLtMatrixLayoutCreate(&mBDesc, mBType, transb == CUBLAS_OP_N ? k : n, transb == CUBLAS_OP_N ? n : k, ldb));
+    check_cuda_error(cublasLtMatrixLayoutCreate(&mCDesc, mCType, m, n, ldc));
+    check_cuda_error(cublasLtMatmulDescCreate(&mOperationDesc, mComputeType, mScaleType));
+    check_cuda_error(cublasLtMatmulDescSetAttribute(
+        mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &transa, sizeof(cublasOperation_t)));
+    check_cuda_error(cublasLtMatmulDescSetAttribute(
+        mOperationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &transb, sizeof(cublasOperation_t)));
+    check_cuda_error(
+        cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAcc, sizeof(int8_t)));
+}
+
+void CublasMMWrapper::setScaleDescriptors(void* scale_a, void* scale_b)
+{
+    check_cuda_error(
+        cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scale_a, sizeof(void*)));
+    check_cuda_error(
+        cublasLtMatmulDescSetAttribute(mOperationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scale_b, sizeof(void*)));
+}
+
+void CublasMMWrapper::destroyDescriptors()
+{
+    check_cuda_error(cublasLtMatmulDescDestroy(mOperationDesc));
+    check_cuda_error(cublasLtMatrixLayoutDestroy(mADesc));
+    check_cuda_error(cublasLtMatrixLayoutDestroy(mBDesc));
+    check_cuda_error(cublasLtMatrixLayoutDestroy(mCDesc));
+    mOperationDesc = NULL;
+    mADesc = NULL;
+    mBDesc = NULL;
+    mCDesc = NULL;
+}
+
+void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
+    void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc)
+{
+    Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f);
+}
+
+void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
+    void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc,
+    std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
+{
+    if (heuristic)
+    {
+        Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, /* hasAlgo */ (*heuristic).algo,
+            (*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE,
+            /* usingCublasLt */ true);
+    }
+    else
+    {
+        Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f, {}, /* hasAlgo */ false,
+            /* usingCublasLt */ true);
+    }
+}
+
+void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
+    void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
+    std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
+{
+    if (heuristic)
+    {
+        Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, /* hasAlgo */ (*heuristic).algo,
+            (*heuristic).state == CUBLAS_STATUS_SUCCESS && (*heuristic).workspaceSize < CUBLAS_WORKSPACE_SIZE,
+            /* usingCublasLt */ true);
+    }
+    else
+    {
+        Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
+            /* usingCublasLt */ true);
+    }
+}
+
+void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
+    void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta)
+{
+    bool usingCublasLt = mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3;
+
+    Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, f_alpha, f_beta, {}, /* hasAlgo */ false,
+        /* usingCublasLt */ usingCublasLt);
+}
+
+void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
+    void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
+    cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt)
+{
+    half h_alpha = (half) (f_alpha);
+    half h_beta = (half) (f_beta);
+
+    // TODO: default cublas libs
+    usingCublasLt = usingCublasLt && (mAType == CUDA_R_16F || mAType == CUDA_R_8F_E4M3);
+    bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F;
+    int batch_count = 1;
+    // fp32 use cublas as default
+    // fp16 use cublasLt as default
+    void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
+    void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
+    int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
+
+    if (usingCublasLt)
+    {
+        if (hasAlgo)
+        {
+            hasAlgo = checkTactic(transa, transb, m, n, k, lda, ldb, ldc, algo);
+        }
+
+        check_cuda_error(cublasLtMatmul(getCublasLtHandle(), mOperationDesc, alpha, A, mADesc, B, mBDesc, beta, C,
+            mCDesc, C, mCDesc, (hasAlgo ? (&algo) : NULL), mCublasWorkspace, workspaceSize, mStream));
+
+        sync_check_cuda_error();
+    }
+    else
+    {
+        check_cuda_error(cublasSetStream(getCublasHandle(), mStream));
+        check_cuda_error(cublasSetWorkspace(getCublasHandle(), mCublasWorkspace, workspaceSize));
+        // Go with default heuristic to choose tactic as cuBLAS does not allow to choose tactics in Ampere+
+        cublasGemmAlgo_t cublasAlgo = CUBLAS_GEMM_DEFAULT;
+        check_cuda_error(cublasGemmEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, B, mBType, ldb,
+            beta, C, mCType, ldc, mComputeType, static_cast<cublasGemmAlgo_t>(cublasAlgo)));
+        sync_check_cuda_error();
+    }
+}
+
+void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
+    int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb,
+    const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha,
+    float const f_beta)
+{
+    half h_alpha = (half) f_alpha;
+    half h_beta = (half) f_beta;
+
+    int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
+    void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
+    void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
+
+    check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda,
+        strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType,
+        mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+}
+
+void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
+    int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA,
+    void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C,
+    cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType)
+{
+    half h_alpha = (half) f_alpha;
+    half h_beta = (half) f_beta;
+
+    bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
+    void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
+    void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
+
+    check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda,
+        strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType,
+        mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+}
+
+void CublasMMWrapper::setWorkspace(void* workspace)
+{
+    mCublasWorkspace = workspace;
+}
+
+void CublasMMWrapper::setFP32GemmConfig()
+{
+    setGemmConfig(CUDA_R_32F, CUDA_R_32F, CUDA_R_32F, CUDA_R_32F);
+}
+
+void CublasMMWrapper::setFP16GemmConfig(cudaDataType_t outputType)
+{
+    setGemmConfig(CUDA_R_16F, CUDA_R_16F, outputType, CUDA_R_32F);
+}
+
+#ifdef ENABLE_BF16
+void CublasMMWrapper::setBF16GemmConfig(cudaDataType_t outputType)
+{
+    setGemmConfig(CUDA_R_16BF, CUDA_R_16BF, outputType, CUDA_R_32F);
+}
+#endif
+
+#ifdef ENABLE_FP8
+void CublasMMWrapper::setFP8GemmConfig(cudaDataType_t outputType)
+{
+    setGemmConfig(CUDA_R_8F_E4M3, CUDA_R_8F_E4M3, outputType, CUDA_R_32F);
+}
+#endif
+
+void CublasMMWrapper::setGemmConfig(
+    cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType)
+{
+    mAType = aType;
+    mBType = bType;
+    mCType = cType;
+    bool isFp16ComputeType = computeType == CUDA_R_16F;
+    if (isFp16ComputeType)
+    {
+        mComputeType = CUBLAS_COMPUTE_16F;
+        mScaleType = CUDA_R_16F;
+    }
+    else
+    {
+        mComputeType = CUBLAS_COMPUTE_32F;
+        mScaleType = CUDA_R_32F;
+    }
+}
+
+CublasDataType CublasMMWrapper::getCublasDataType(cudaDataType_t data_type)
+{
+    if (data_type == CUDA_R_16F)
+    {
+        return HALF_DATATYPE;
+    }
+    else if (data_type == CUDA_R_32F)
+    {
+        return FLOAT_DATATYPE;
+    }
+    else if (data_type == CUDA_R_8I)
+    {
+        return INT8_DATATYPE;
+    }
+#ifdef ENABLE_BF16
+    else if (data_type == CUDA_R_16BF)
+    {
+        return BFLOAT16_DATATYPE;
+    }
+#endif
+    return FLOAT_DATATYPE;
+}
+
+void CublasMMWrapper::setStream(cudaStream_t stream)
+{
+    mStream = stream;
+}
+
+bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
+    int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo)
+{
+    TLLM_CHECK_WITH_INFO(
+        descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
+
+    int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
+
+    cublasLtMatmulHeuristicResult_t heurResult;
+    cublasStatus_t algoStatus = cublasLtMatmulAlgoCheck(
+        getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc, &algo, &heurResult);
+
+    if (algoStatus != CUBLAS_STATUS_SUCCESS || heurResult.state != CUBLAS_STATUS_SUCCESS
+        || heurResult.workspaceSize > CUBLAS_WORKSPACE_SIZE)
+    {
+        return false;
+    }
+
+    sync_check_cuda_error();
+
+    return true;
+}
+
+std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasOperation_t transa,
+    cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc)
+{
+    TLLM_CHECK_WITH_INFO(
+        descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
+
+    auto const heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc);
+
+    sync_check_cuda_error();
+
+    return heuristics;
+}
+
+std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasLtHandle_t lightHandle,
+    cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
+    cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc)
+{
+#if TLLM_CUBLAS_VER_LE(11, 4, 2)
+    TLLM_CHECK_WITH_INFO(false, "CUBLAS version too low, must be > 11.4.2.");
+    return {};
+#else
+    std::vector<cublasLtMatmulHeuristicResult_t> heuristics(200);
+    cublasLtMatmulPreference_t preference;
+    check_cuda_error(cublasLtMatmulPreferenceCreate(&preference));
+    check_cuda_error(cublasLtMatmulPreferenceInit(preference));
+    uint64_t workspace_size = CUBLAS_WORKSPACE_SIZE;
+    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
+        preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size)));
+    // Restrict reduction algorithms for numerical stability and better determinism
+    uint32_t reduction_mask = CUBLASLT_REDUCTION_SCHEME_MASK;
+    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
+        preference, CUBLASLT_MATMUL_PREF_REDUCTION_SCHEME_MASK, &reduction_mask, sizeof(reduction_mask)));
+#if TLLM_CUBLAS_VER_LT(12, 0, 0)
+    uint32_t pointer_mode_mask = 0;
+    check_cuda_error(cublasLtMatmulPreferenceSetAttribute(
+        preference, CUBLASLT_MATMUL_PREF_EPILOGUE_MASK, &pointer_mode_mask, sizeof(pointer_mode_mask)));
+#endif
+
+    int return_count = 0;
+    check_cuda_error(cublasLtMatmulAlgoGetHeuristic(lightHandle, computeDesc, Adesc, Bdesc, Cdesc, Ddesc, preference,
+        heuristics.size(), heuristics.data(), &return_count));
+    heuristics.resize(return_count);
+
+    return heuristics;
+#endif
+}
+
+} // namespace common
+
+} // namespace tensorrt_llm
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h
new file mode 100644
index 00000000000..79b7c92a47d
--- /dev/null
+++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasMMWrapper.h
@@ -0,0 +1,148 @@
+/*
+ * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "tensorrt_llm/common/cudaUtils.h"
+#include <cublasLt.h>
+#include <cublas_v2.h>
+#include <cuda_runtime.h>
+#include <map>
+#include <optional>
+#include <string>
+
+namespace tensorrt_llm
+{
+namespace common
+{
+
+class CublasMMWrapper
+{
+protected:
+    std::shared_ptr<cublasHandle_t> mCublasHandle;
+    std::shared_ptr<cublasLtHandle_t> mCublasLtHandle;
+
+    cudaDataType_t mAType{};
+    cudaDataType_t mBType{};
+    cudaDataType_t mCType{};
+    cublasComputeType_t mComputeType{};
+    cudaDataType_t mScaleType{};
+
+    cublasLtMatmulDesc_t mOperationDesc{NULL};
+    cublasLtMatrixLayout_t mADesc{NULL};
+    cublasLtMatrixLayout_t mBDesc{NULL};
+    cublasLtMatrixLayout_t mCDesc{NULL};
+
+    cudaStream_t mStream;
+
+    void* mCublasWorkspace = nullptr;
+
+private:
+    bool descriptorsCreated() const
+    {
+        return mOperationDesc != NULL && mADesc != NULL && mBDesc != NULL && mCDesc != NULL;
+    }
+
+public:
+    CublasMMWrapper(std::shared_ptr<cublasHandle_t> cublasHandle, std::shared_ptr<cublasLtHandle_t> cublasLtHandle,
+        cudaStream_t stream, void* workspace);
+
+    ~CublasMMWrapper();
+
+    CublasMMWrapper(CublasMMWrapper const& wrapper);
+
+    /********************** GEMMs **********************/
+    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
+        int const lda, void const* B, int const ldb, void* C, int const ldc);
+
+    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
+        int const lda, void const* B, int const ldb, void* C, int const ldc,
+        std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
+
+    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
+        int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
+        std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
+
+    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
+        int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta);
+
+    void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
+        int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
+        cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt);
+
+    void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
+        void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB,
+        void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f,
+        float const f_beta = 0.0f);
+
+    void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
+        float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B,
+        cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType,
+        int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType);
+
+    /********************** Tactic selection helpers **********************/
+    bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
+        int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo);
+
+    std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
+        int const m, int const n, int const k, int const lda, int const ldb, int const ldc);
+
+    std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
+        cublasLtMatmulDesc_t
computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, + cublasLtMatrixLayout_t Cdesc, cublasLtMatrixLayout_t Ddesc); + + using MatrixLayout = std::tuple; + using cache_idx_t = std::tuple>; + + MatrixLayout createMatrixLayout(cublasLtMatrixLayout_t Mdesc); + + /********************** Utils **********************/ + void setWorkspace(void* workspace); + + void setFP32GemmConfig(); + void setFP16GemmConfig(cudaDataType_t outputType = CUDA_R_16F); +#ifdef ENABLE_BF16 + void setBF16GemmConfig(cudaDataType_t outputType = CUDA_R_16BF); +#endif +#ifdef ENABLE_FP8 + void setFP8GemmConfig(cudaDataType_t outputType = CUDA_R_16F); +#endif + + void setStream(cudaStream_t stream); + + void setGemmConfig(cudaDataType_t aType, cudaDataType_t bType, cudaDataType_t cType, cudaDataType_t computeType); + + CublasDataType getCublasDataType(cudaDataType_t data_type); + + void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, + int const lda, int const ldb, int const ldc, int8_t fastAcc = 0); + void setScaleDescriptors(void* scale_a, void* scale_b); + void destroyDescriptors(); + + cublasHandle_t getCublasHandle() + { + return *(this->mCublasHandle); + } + + cublasLtHandle_t getCublasLtHandle() const + { + return *(this->mCublasLtHandle); + } +}; + +} // namespace common + +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h new file mode 100644 index 00000000000..1ee72c63566 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cublasVersionCheck.h @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// We don't want to include cublas_api.h. It contains the CUBLAS_VER_* macro +// definition which is not sufficient to determine if we include cublas.h, +// cublas_v2.h or cublasLt.h. 
+ +#define TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) (MAJOR * 10000 + MINOR * 100 + PATCH) +#define TLLM_CUBLAS_VER_LE(MAJOR, MINOR, PATCH) \ + TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ + <= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) +#define TLLM_CUBLAS_VER_LT(MAJOR, MINOR, PATCH) \ + TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ + < TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) +#define TLLM_CUBLAS_VER_GE(MAJOR, MINOR, PATCH) \ + TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ + >= TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) +#define TLLM_CUBLAS_VER_GT(MAJOR, MINOR, PATCH) \ + TLLM_CUBLAS_VERSION_CALC(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, CUBLAS_VER_PATCH) \ + > TLLM_CUBLAS_VERSION_CALC(MAJOR, MINOR, PATCH) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh new file mode 100644 index 00000000000..0519251e6fd --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaBf16Fallbacks.cuh @@ -0,0 +1,313 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/cudaBf16Wrapper.h" +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ + +#ifdef ENABLE_BF16 +inline __device__ float2 bf1622float2(const __nv_bfloat162 val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = __low2float(val); + f_val.y = __high2float(val); + return f_val; +#else + return __bfloat1622float2(val); +#endif +} + +inline __device__ int16_t bf1622int16(__nv_bfloat162 val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = max(min(__low2float(val), 127.f), -128.f); + f_val.y = max(min(__high2float(val), 127.f), -128.f); + + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = static_cast(static_cast(f_val.x)); + int8[1] = static_cast(static_cast(f_val.y)); + return int16; +#else + val = __hmin2(val, make_bfloat162(127., 127.)); + val = __hmax2(val, make_bfloat162(-128., -128.)); + + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = static_cast(static_cast(val.x)); + int8[1] = static_cast(static_cast(val.y)); + return int16; +#endif +} + +inline __device__ __nv_bfloat162 float22bf162(const float2 val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __floats2bfloat162_rn(val.x, val.y); +#else + return __float22bfloat162_rn(val); +#endif +} + +inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + __nv_bfloat162 val2; + val2.x = val; + val2.y = val; + return val2; +#else + return __bfloat162bfloat162(val); +#endif +} + +inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); 
+ fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); +#else + return __hadd2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y)); +#else + return __hadd(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); +#else + return __hsub2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y)); +#else + return __hsub(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); +#else + return __hmul2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y)); +#else + return __hmul(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh, fzl, fzh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + fzl = __low2float(z); + fzh = __high2float(z); + return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); +#else + return __hfma2(x, y, z); +#endif +} + +inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); +#else + return __hfma(x, y, z); +#endif +} + +inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh; + fxl = __low2float(x); + fxh = __high2float(x); + ; + return __floats2bfloat162_rn(expf(fxl), expf(fxh)); +#else + return h2exp(x); +#endif +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) +#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020) + +inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ + __nv_bfloat162 t; + t.x = x; + t.y = y; + return t; +} +#endif +#endif + +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); +#else + return a + b + c; +#endif +} + +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + 
__bfloat162float(d)); +#else + return (__nv_bfloat16) ((float) a + (float) b + (float) c + (float) d); +#endif +} + +inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); +#else + return a + b + c; +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); +#else + return a * b * c; +#endif +} + +inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); +#else + return a * b * c; +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + fdl = __low2float(d); + fdh = __high2float(d); + return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); +#else + return a * b * c + d; +#endif +} + +#endif // ENABLE_BF16 + +} // namespace common +} // namespace tensorrt_llm + +// Operator definitions intentionally in global namespace +namespace +{ +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) +#if defined(CUDART_VERSION) && (CUDART_VERSION < 12020) + +inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ + return tensorrt_llm::common::bf16hmul2(x, y); +}; + +inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) +{ + return tensorrt_llm::common::bf16hadd2(x, y); +}; +#endif +#endif +} // namespace diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp new file mode 100644 index 00000000000..7eca46a1cab --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#define CUDA_LIB_NAME "cuda" + +#if defined(_WIN32) +#include +#define dllOpen(name) LoadLibrary("nv" name ".dll") +#define dllClose(handle) FreeLibrary(static_cast(handle)) +#define dllGetSym(handle, name) static_cast(GetProcAddress(static_cast(handle), name)) +#else // For non-Windows platforms +#include +#define dllOpen(name) dlopen("lib" name ".so.1", RTLD_LAZY) +#define dllClose(handle) dlclose(handle) +#define dllGetSym(handle, name) dlsym(handle, name) +#endif // defined(_WIN32) + +#include "cudaDriverWrapper.h" +#include "tensorrt_llm/common/assert.h" +#include +#include + +namespace tensorrt_llm::common +{ + +std::shared_ptr CUDADriverWrapper::getInstance() +{ + static std::mutex mutex; + static std::weak_ptr instance; + std::shared_ptr result = instance.lock(); + if (result) + { + return result; + } + + std::lock_guard lock(mutex); + result = instance.lock(); + if (!result) + { + result = std::shared_ptr(new CUDADriverWrapper()); + instance = result; + } + return result; +} + +CUDADriverWrapper::CUDADriverWrapper() + : handle(dllOpen(CUDA_LIB_NAME)) +{ + + TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly."); + + auto load_sym = [](void* handle, char const* name) + { + void* ret = dllGetSym(handle, name); + return ret; + }; + + *reinterpret_cast(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName"); + *reinterpret_cast(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage"); + *reinterpret_cast(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute"); + *reinterpret_cast(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete"); + *reinterpret_cast(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload"); + *reinterpret_cast(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy"); + *reinterpret_cast(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData"); + *reinterpret_cast(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2"); + *reinterpret_cast(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction"); + *reinterpret_cast(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2"); + *reinterpret_cast(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2"); + *reinterpret_cast(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2"); + *reinterpret_cast(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel"); + *reinterpret_cast(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel"); + *reinterpret_cast(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled"); + *reinterpret_cast(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2"); +} + +CUDADriverWrapper::~CUDADriverWrapper() +{ + dllClose(handle); +} + +CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const +{ + return (*_cuGetErrorName)(error, pStr); +} + +CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const +{ + return (*_cuGetErrorMessage)(error, pStr); +} + +CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const +{ + return (*_cuFuncSetAttribute)(hfunc, attrib, value); +} + +CUresult CUDADriverWrapper::cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const +{ + return (*_cuLinkComplete)(state, cubinOut, sizeOut); +} + +CUresult CUDADriverWrapper::cuModuleUnload(CUmodule hmod) const +{ + return (*_cuModuleUnload)(hmod); +} + +CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const +{ + return (*_cuLinkDestroy)(state); +} + +CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* 
module, void const* image) const +{ + return (*_cuModuleLoadData)(module, image); +} + +CUresult CUDADriverWrapper::cuLinkCreate( + unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const +{ + return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut); +} + +CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const +{ + return (*_cuModuleGetFunction)(hfunc, hmod, name); +} + +CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const +{ + return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name); +} + +CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, + unsigned int numOptions, CUjit_option* options, void** optionValues) const +{ + return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues); +} + +CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, + char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const +{ + return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues); +} + +CUresult CUDADriverWrapper::cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const +{ + return (*_cuLaunchCooperativeKernel)( + f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams); +} + +CUresult CUDADriverWrapper::cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra) const +{ + return (*_cuLaunchKernel)( + f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams, extra); +} + +CUresult CUDADriverWrapper::cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, + cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const +{ + return (*_cuTensorMapEncodeTiled)(tensorMap, tensorDataType, tensorRank, globalAddress, globalDim, globalStrides, + boxDim, elementStrides, interleave, swizzle, l2Promotion, oobFill); +} + +CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const +{ + return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount); +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h new file mode 100644 index 00000000000..c4d470a85f0 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaDriverWrapper.h @@ -0,0 +1,138 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef CUDA_DRIVER_WRAPPER_H +#define CUDA_DRIVER_WRAPPER_H + +#include "tensorrt_llm/common/assert.h" +#include +#include +#include +#include + +namespace tensorrt_llm::common +{ + +class CUDADriverWrapper +{ +public: + static std::shared_ptr getInstance(); + + ~CUDADriverWrapper(); + CUDADriverWrapper(CUDADriverWrapper const&) = delete; + CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete; + CUDADriverWrapper(CUDADriverWrapper&&) = delete; + CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete; + + CUresult cuGetErrorName(CUresult error, char const** pStr) const; + + CUresult cuGetErrorMessage(CUresult error, char const** pStr) const; + + CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const; + + CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const; + + CUresult cuModuleUnload(CUmodule hmod) const; + + CUresult cuLinkDestroy(CUlinkState state) const; + + CUresult cuModuleLoadData(CUmodule* module, void const* image) const; + + CUresult cuLinkCreate( + unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const; + + CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const; + + CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const; + + CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions, + CUjit_option* options, void** optionValues) const; + + CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name, + unsigned int numOptions, CUjit_option* options, void** optionValues) const; + + CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, + unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, + unsigned int sharedMemBytes, CUstream hStream, void** kernelParams) const; + + CUresult cuLaunchKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, + CUstream hStream, void** kernelParams, void** extra) const; + + CUresult cuTensorMapEncodeTiled(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, cuuint32_t tensorRank, + void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, cuuint32_t const* boxDim, + cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill) const; + + CUresult cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount) const; + +private: + void* handle; + CUDADriverWrapper(); + + CUresult (*_cuGetErrorName)(CUresult, char const**); + CUresult (*_cuGetErrorMessage)(CUresult, char const**); + CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int); + CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*); + CUresult (*_cuModuleUnload)(CUmodule); + CUresult 
(*_cuLinkDestroy)(CUlinkState); + CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*); + CUresult (*_cuModuleLoadData)(CUmodule*, void const*); + CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*); + CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*); + CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**); + CUresult (*_cuLinkAddData)( + CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**); + CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, + unsigned int, unsigned int, unsigned int, CUstream, void**); + CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, + unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, + CUstream hStream, void** kernelParams, void** extra); + CUresult (*_cuTensorMapEncodeTiled)(CUtensorMap* tensorMap, CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, void* globalAddress, cuuint64_t const* globalDim, cuuint64_t const* globalStrides, + cuuint32_t const* boxDim, cuuint32_t const* elementStrides, CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion, CUtensorMapFloatOOBfill oobFill); + CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount); +}; + +template +void checkDriver( + T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line) +{ + if (result) + { + char const* errorName = nullptr; + char const* errorMsg = nullptr; + wrap.cuGetErrorName(result, &errorName); + wrap.cuGetErrorMessage(result, &errorMsg); + throw TllmException( + file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg)); + } +} + +} // namespace tensorrt_llm::common + +/* + * Macros compliant with TensorRT coding conventions + */ +#define TLLM_CU_CHECK(stat) \ + do \ + { \ + tensorrt_llm::common::checkDriver( \ + (stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \ + } while (0) + +#endif // CUDA_DRIVER_WRAPPER_H diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu new file mode 100644 index 00000000000..8e140609f2a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaFp8Utils.cu @@ -0,0 +1,436 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/reduceKernelUtils.cuh" +#include +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ +#ifdef ENABLE_FP8 + +constexpr int CTA_SIZE = 256; + +template +__inline__ __device__ float scale(float a, float b) +{ + return QUANTIZE ? a / b : a * b; +} + +template +__global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda) +{ + for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x) + { + + if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL) + { + output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[i % lda]))); + } + else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN) + { + output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[i / lda]))); + } + else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR) + { + output[i] = T_OUT(scale(static_cast(input[i]), static_cast(input_scale[0]))); + } + } +} + +template +void invokeQuantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream) +{ + dim3 grid(1024); + dim3 block(CTA_SIZE); + if (quantize_mode == QuantizeMode::PER_CHANNEL) + { + scaleMatrix + <<>>(output, input_scale, input, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TOKEN) + { + scaleMatrix<<>>(output, input_scale, input, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TENSOR) + { + scaleMatrix<<>>(output, input_scale, input, numel, lda); + } + sync_check_cuda_error(); +} + +template +void invokeDequantizeMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream) +{ + dim3 grid(1024); + dim3 block(CTA_SIZE); + if (quantize_mode == QuantizeMode::PER_CHANNEL) + { + scaleMatrix + <<>>(output, input_scale, input, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TOKEN) + { + scaleMatrix<<>>(output, input_scale, input, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TENSOR) + { + scaleMatrix + <<>>(output, input_scale, input, numel, lda); + } + sync_check_cuda_error(); +} + +template +__global__ void fakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel) +{ + for (int64_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < numel; tid += blockDim.x * gridDim.x) + { + T_FAKE tmp = (T_FAKE) (static_cast(src[tid])); + dst[tid] = (T_OUT) (static_cast(tmp)); + } +} + +template +void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaStream_t stream) +{ + fakeQuantize<<<1024, CTA_SIZE, 0, stream>>>(dst, src, numel); + sync_check_cuda_error(); +} + +template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>( + float* dst, float const* src, const int64_t numel, cudaStream_t stream); +template void invokeFakeQuantize( + float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream); +template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>( + half* dst, half const* src, const int64_t numel, cudaStream_t stream); +template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>( + __nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream); + +template void invokeFakeQuantize( + half* dst, float const* src, const int64_t numel, cudaStream_t stream); + +__device__ float atomicMaxExtd(float* address, float 
val) +{ + assert(val >= 0); + unsigned int* address_as_u = reinterpret_cast(address); + unsigned int old = atomicMax(address_as_u, __float_as_uint(val)); + return __uint_as_float(old); +} + +template +inline __device__ T atomicMaxExtdV2(T* address, T val) +{ +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + static_assert(std::is_same_v | std::is_same_v, "T needs to be either half or bfloat16"); + // The address in 64 bits. + uint64_t address_u64 = reinterpret_cast(address); + + // Pack the input value into 32 bits. + union + { + T v[2]; + uint16_t u[2]; + } old, tmp = {}; + + int const loc = (address_u64 & 0x2) >> 1; + tmp.v[loc] = val; + + // 4B aligned pointer. + auto aligned_address = reinterpret_cast(address_u64 & ~0x3ull); + + if constexpr (std::is_same_v) + { + asm volatile("atom.global.v2.f16.max.noftz {%0, %1}, [%2], {%3, %4};" + : "=h"(old.u[0]), "=h"(old.u[1]) + : "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1])); + } + if constexpr (std::is_same_v) + { + asm volatile("atom.global.v2.bf16.max.noftz {%0, %1}, [%2], {%3, %4};" + : "=h"(old.u[0]), "=h"(old.u[1]) + : "l"(aligned_address), "h"(tmp.u[0]), "h"(tmp.u[1])); + } + + // Return the correct half. + return old.v[loc]; +#endif +} + +__device__ half atomicMaxExtd(half* address, half val) +{ + unsigned short int* address_as_u = reinterpret_cast(address); + unsigned short int old = *address_as_u, assumed; + + while (val > __ushort_as_half(old)) + { + assumed = old; + old = atomicCAS(address_as_u, assumed, __half_as_ushort(val)); + } + + return __ushort_as_half(old); +} + +__device__ __nv_bfloat16 atomicMaxExtd(__nv_bfloat16* address, __nv_bfloat16 val) +{ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + unsigned short int* address_as_u = reinterpret_cast(address); + unsigned short int old = *address_as_u, assumed; + + while (val > __ushort_as_bfloat16(old)) + { + assumed = old; + old = atomicCAS(address_as_u, assumed, __bfloat16_as_ushort(val)); + } + + return __ushort_as_bfloat16(old); +#else + assert(0); + asm volatile("brkpt;\n" ::); + return __nv_bfloat16(0); +#endif +} + +template +__global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t size, const int64_t n) +{ + constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); + if (QUANTIZE_MODE == QuantizeMode::PER_CHANNEL) + { + for (int64_t col = threadIdx.x; col < n; col += blockDim.x) + { + float max = 0.f; + for (int64_t i = col + n * blockIdx.x; i < size; i += gridDim.x * n) + { + auto val = fabs(static_cast(weights[i])); + max = max > val ? max : val; + } + auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + if constexpr (std::is_same_v) + { + atomicMaxExtd(quant_ptr + col, scale); + } + else + { + auto const address_u64 = reinterpret_cast(quant_ptr + col); + if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0)) + atomicMaxExtd(quant_ptr + col, scale); + else + atomicMaxExtdV2(quant_ptr + col, scale); + } +#else // Vector atomics require __CUDA_ARCH__ >= 900 + atomicMaxExtd(quant_ptr + col, scale); +#endif + } + } + else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN) + { + auto const nrows = size / n; + for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x) + { + float max = 0.f; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) + { + auto val = fabs(static_cast(weights[row * n + i])); + max = max > val ? 
max : val; + } + max = blockReduceMax(max); + if (threadIdx.x == 0) + { + auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); + quant_ptr[row] = scale; + } + } + } + else if (QUANTIZE_MODE == QuantizeMode::PER_TENSOR) + { + float max = 0.f; + for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < size; i += gridDim.x * blockDim.x) + { + auto val = fabs(static_cast(weights[i])); + max = max > val ? max : val; + } + max = blockReduceMax(max); + if (threadIdx.x == 0) + { + auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); + atomicMaxExtd(quant_ptr, scale); + } + } +} + +template +void invokeComputeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, const int64_t numel, const int64_t lda, + QuantizeMode quantize_mode, cudaStream_t stream) +{ + if (quantize_mode == QuantizeMode::PER_TOKEN) + { + dim3 block(CTA_SIZE); + dim3 grid(numel / lda); + computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_CHANNEL) + { + dim3 block(CTA_SIZE); + dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE); + cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream); + sync_check_cuda_error(); + computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); + } + else if (quantize_mode == QuantizeMode::PER_TENSOR) + { + dim3 block(1024); + dim3 grid(1024); + cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream); + sync_check_cuda_error(); + computeFP8QuantizeScale<<>>(quant_ptr, weights, numel, lda); + } + sync_check_cuda_error(); +} + +#define DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(type_scale, type_in) \ + template void invokeComputeFP8QuantizeScale(type_scale * input_scale, type_in const* weights, \ + int64_t numel, int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream); + +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(half, half); +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, half); +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, float); +#ifdef ENABLE_BF16 +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(__nv_bfloat16, __nv_bfloat16); +DEFINE_INVOKE_COMPUTE_FP8_QUANTIZE_SCALE(float, __nv_bfloat16); +#endif + +template +__global__ void dynamicQuantizeMatrixPerToken( + T_OUT* output, T_S* quant_ptr, T_IN const* input, int64_t numel, int64_t lda) +{ + extern __shared__ __align__(sizeof(float)) char _shmem[]; + T_IN* shmem = reinterpret_cast(_shmem); + constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); + auto const nrows = numel / lda; + for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x) + { + float max = 0.f; + for (int64_t i = threadIdx.x; i < lda; i += blockDim.x) + { + auto const in = input[row * lda + i]; + shmem[i] = in; + auto val = fabs(static_cast(in)); + max = max > val ? 
max : val; + } + max = blockAllReduceMax(max); // __syncthreads() called so we can read shmem + auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); + for (int64_t i = threadIdx.x; i < lda; i += blockDim.x) + { + // true means we are quantizing + output[row * lda + i] = (T_OUT) scale(static_cast(shmem[i]), static_cast(s)); + } + if (threadIdx.x == 0) + { + quant_ptr[row] = s; + } + } +} + +template +void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T_IN* input, const int64_t numel, + const int64_t lda, QuantizeMode quantize_mode, cudaStream_t stream) +{ + if (quantize_mode == QuantizeMode::PER_TOKEN) + { + dim3 grid(numel / lda); + bool use_shmem = true; + auto const shmem_size = lda * sizeof(T_IN); + if (shmem_size >= (48 << 10)) + { + cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + use_shmem = ret == cudaSuccess; + } + if (use_shmem) + { + // ensure the threadblock is as large as possible to increase occupancy + dim3 block(std::min((lda + 31) / 32 * 32, static_cast(1024))); + dynamicQuantizeMatrixPerToken<<>>(output, quant_ptr, input, numel, lda); + } + else + { + dim3 block(CTA_SIZE); + computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); + sync_check_cuda_error(); + invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); + } + } + else if (quantize_mode == QuantizeMode::PER_CHANNEL) + { + dim3 block(CTA_SIZE); + dim3 grid((lda + CTA_SIZE - 1) / CTA_SIZE); + cudaMemsetAsync(quant_ptr, 0, lda * sizeof(T_S), stream); + sync_check_cuda_error(); + computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); + sync_check_cuda_error(); + invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); + } + else if (quantize_mode == QuantizeMode::PER_TENSOR) + { + dim3 block(1024); + dim3 grid(1024); + cudaMemsetAsync(quant_ptr, 0, sizeof(T_S), stream); + sync_check_cuda_error(); + computeFP8QuantizeScale<<>>(quant_ptr, input, numel, lda); + sync_check_cuda_error(); + invokeQuantizeMatrix(output, quant_ptr, input, numel, lda, quantize_mode, stream); + } + sync_check_cuda_error(); +} + +#define DEFINE_INVOKE_QUANTIZE_MATRIX(type_out, type_scale, type_in) \ + template void invokeQuantizeMatrix(type_out * output, \ + type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ + cudaStream_t stream); \ + template void invokeDequantizeMatrix(type_out * output, \ + type_scale const* input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ + cudaStream_t stream); \ + template void invokeComputeScalesAndQuantizeMatrix(type_out * output, \ + type_scale * input_scale, type_in const* input, int64_t numel, int64_t lda, QuantizeMode quantize_mode, \ + cudaStream_t stream); + +#ifdef ENABLE_FP8 +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, float); +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, float, half); +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, half, half); +DEFINE_INVOKE_QUANTIZE_MATRIX(half, half, __nv_fp8_e4m3); +DEFINE_INVOKE_QUANTIZE_MATRIX(float, float, __nv_fp8_e4m3); +DEFINE_INVOKE_QUANTIZE_MATRIX(half, float, __nv_fp8_e4m3); +#ifdef ENABLE_BF16 +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16); +DEFINE_INVOKE_QUANTIZE_MATRIX(__nv_bfloat16, __nv_bfloat16, __nv_fp8_e4m3); +#endif +#endif + +#endif // ENABLE_FP8 +} // namespace common +} // namespace tensorrt_llm diff --git 
a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp new file mode 100644 index 00000000000..5576fe782fa --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaProfilerUtils.cpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/cudaProfilerUtils.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/stringUtils.h" +#include +#include + +namespace +{ + +std::tuple, std::unordered_set> populateIterationIndexesImpl( + std::string const& envVarName) +{ + auto envVarVal = std::getenv(envVarName.c_str()); + auto envVarValStr = std::string{envVarVal != nullptr ? envVarVal : ""}; + auto values = tensorrt_llm::common::str2set(envVarValStr, ','); + std::unordered_set startSet; + std::unordered_set endSet; + for (std::string const& value : values) + { + size_t dashIdx = value.find("-"); + if (dashIdx != std::string::npos) + { + int32_t start = std::stoi(value.substr(0, dashIdx)); + startSet.insert(start); + int32_t end = std::stoi(value.substr(dashIdx + 1)); + endSet.insert(end); + } + else + { + int32_t start_end = std::stoi(value); + startSet.insert(start_end); + endSet.insert(start_end); + } + } + + return std::make_pair(startSet, endSet); +} + +} // namespace + +namespace tensorrt_llm::common +{ + +std::pair, std::unordered_set> populateIterationIndexes( + std::string const& envVarName, std::optional const& legacyEnvVarName) +{ + auto [profileIterIdxs, stopIterIdxs] = populateIterationIndexesImpl(envVarName); + + // If empty, try to use legacy env var name + if (legacyEnvVarName && profileIterIdxs.empty() && stopIterIdxs.empty()) + { + std::tie(profileIterIdxs, stopIterIdxs) = populateIterationIndexesImpl(legacyEnvVarName.value()); + + if (!profileIterIdxs.empty() || !stopIterIdxs.empty()) + { + TLLM_LOG_WARNING( + "Using deprecated environment variable %s to specify cudaProfiler start and stop iterations. " + "Please " + "use %s " + "instead.", + legacyEnvVarName.value().c_str(), envVarName.c_str()); + } + } + + return std::make_pair(profileIterIdxs, stopIterIdxs); +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh new file mode 100644 index 00000000000..a0463a3a49e --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/cudaTypeUtils.cuh @@ -0,0 +1,752 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" +#include "tensorrt_llm/common/cudaBf16Wrapper.h" +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include +#include +#include +#if ENABLE_BF16 +#include +#endif + +namespace tensorrt_llm +{ +namespace common +{ + +template +inline __device__ T ldg(T const* val) +{ + return __ldg(val); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 ldg(__nv_bfloat162 const* val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return val[0]; +#else + return __ldg(val); +#endif +} + +template <> +inline __device__ __nv_bfloat16 ldg(__nv_bfloat16 const* val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return val[0]; +#else + return __ldg(val); +#endif +} +#endif // ENABLE_BF16 + +// Get type2 from type or vice versa (applied to half and bfloat16) +template +struct TypeConverter +{ + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter +{ + using Type = half; +}; + +template <> +struct TypeConverter +{ + using Type = half2; +}; + +#if ENABLE_BF16 +template <> +struct TypeConverter<__nv_bfloat162> +{ + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> +{ + using Type = __nv_bfloat162; +}; +#endif // ENABLE_BF16 + +// Defined math operations (bfloat16 fallback to fp32 when it is not supported) +template +inline __device__ T hadd2(T a, T b) +{ + return __hadd2(a, b); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 hadd2(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hadd2(a, b); +} +#endif // ENABLE_BF16 + +template +inline __device__ T add(T a, T b) +{ + return a + b; +} + +template <> +inline __device__ half2 add(half2 a, half2 b) +{ + return __hadd2(a, b); +} + +template <> +inline __device__ half add(half a, half b) +{ + return __hadd(a, b); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hadd2(a, b); +} + +template <> +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) +{ + return bf16hadd(a, b); +} + +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, float b) +{ + return bf16hadd(a, __float2bfloat16(b)); +} +#endif // ENABLE_BF16 + +// applies to all 4 values addition +template +inline __device__ T add(T a, T b, T c) +{ + return a + b + c; +} + +#if ENABLE_BF16 +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ + return bf16hadd(a, b, c); +} + +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hadd2(a, b, c); +} +#endif // ENABLE_BF16 + +// applies to all 4 values addition +template +inline __device__ T add(T a, T b, T c, T d) +{ + return (T) ((float) a + (float) b + (float) c + (float) d); +} + +#if ENABLE_BF16 +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) +{ + return bf16hadd(a, b, c, d); +} +#endif // ENABLE_BF16 + +template +inline __device__ T hsub2(T a, T b) +{ + return __hsub2(a, b); +} + +#if ENABLE_BF16 +template <> +inline 
__device__ __nv_bfloat162 hsub2(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hsub2(a, b); +} +#endif // ENABLE_BF16 + +template +inline __device__ T hmul2(T a, T b) +{ + return __hmul2(a, b); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hmul2(a, b); +} +#endif // ENABLE_BF16 + +template +inline __device__ T hmul2(T a, T b, T c) +{ + return a * b * c; +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hmul2(a, b, c); +} +#endif // ENABLE_BF16 + +template +inline __device__ T mul(T a, T b, T c) +{ + return a * b * c; +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ + return bf16hmul(a, b, c); +} + +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hmul2(a, b, c); +} +#endif // ENABLE_BF16 + +template +inline __device__ T fma(T a, T b, T c, T d) +{ + return a * b * c + d; +} + +#if ENABLE_BF16 +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) +{ + return bf16hfma2(a, b, c, d); +} +#endif // ENABLE_BF16 + +template +inline __device__ T fma(T a, T b, T c) +{ + return a * b + c; +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hfma2(a, b, c); +} + +template <> +inline __device__ __nv_bfloat16 fma(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) +{ + return bf16hfma(a, b, c); +} +#endif // ENABLE_BF16 + +template +inline __device__ T hexp2(T a) +{ + return h2exp(a); +} + +#if ENABLE_BF16 +template <> +inline __device__ __nv_bfloat162 hexp2(__nv_bfloat162 a) +{ + return bf16exp2(a); +} +#endif // ENABLE_BF16 + +template +__device__ inline T_OUT cuda_cast(T_IN val) +{ + return val; +} + +template <> +__device__ inline float2 cuda_cast(int2 val) +{ + return make_float2(val.x, val.y); +} + +template <> +__device__ inline float2 cuda_cast(float val) +{ + return make_float2(val, val); +} + +template <> +__device__ inline float2 cuda_cast(half2 val) +{ + return __half22float2(val); +} + +template <> +__device__ inline half2 cuda_cast(float2 val) +{ + return __float22half2_rn(val); +} + +template <> +__device__ inline half2 cuda_cast(float val) +{ + return __float2half2_rn(val); +} + +template <> +__device__ inline half2 cuda_cast(half val) +{ + return __half2half2(val); +} + +template <> +__device__ inline int8_t cuda_cast(half val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + union + { + half fp16; + int16_t int16_in; + }; + + fp16 = val; + asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in)); + return int8[0]; +} + +template <> +__device__ inline int16_t cuda_cast(half2 val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = cuda_cast(val.x); + int8[1] = cuda_cast(val.y); + return int16; +} + +template <> +__device__ inline int8_t cuda_cast(float val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); + return int8[0]; +} + +template <> +__device__ inline int16_t cuda_cast(float2 val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = cuda_cast(val.x); + int8[1] = cuda_cast(val.y); + return int16; +} + +template <> +__device__ inline half2 cuda_cast(int16_t val) +{ + union + { + int8_t 
int8[2]; + int16_t int16; + }; + + int16 = val; + return make_half2(int8[0], int8[1]); +} + +template <> +__device__ inline float2 cuda_cast(int16_t val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int16 = val; + return make_float2(int8[0], int8[1]); +} + +#ifdef ENABLE_BF16 +template <> +__device__ inline __nv_bfloat16 cuda_cast(int32_t val) +{ + return static_cast(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast(int8_t val) +{ + return static_cast(val); +} + +template <> +__device__ inline int8_t cuda_cast(__nv_bfloat16 val) +{ + return static_cast(val); +} + +template <> +__device__ inline float cuda_cast(__nv_bfloat16 val) +{ + return __bfloat162float(val); +} + +template <> +__device__ inline float2 cuda_cast(__nv_bfloat162 val) +{ + return bf1622float2(val); +} + +template <> +__device__ inline half cuda_cast(__nv_bfloat16 val) +{ + return __float2half(__bfloat162float(val)); +} + +template <> +__device__ inline int16_t cuda_cast(__nv_bfloat162 val) +{ + return bf1622int16(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) +{ + return __float2bfloat16(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val) +{ + return __float2bfloat16(__half2float(val)); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val) +{ + return bf162bf162(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val) +{ + return __float2bfloat162_rn(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val) +{ + return float22bf162(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int16 = val; + __nv_bfloat162 res; + res.x = cuda_cast<__nv_bfloat16>(int8[0]); + res.y = cuda_cast<__nv_bfloat16>(int8[1]); + return res; +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val) +{ + return float22bf162(__half22float2(val)); +} + +#endif // ENABLE BF16 + +template +__device__ inline T cuda_abs(T val) +{ + assert(false); + return {}; +} + +template <> +__device__ inline float cuda_abs(float val) +{ + return fabs(val); +} + +template <> +__device__ inline float2 cuda_abs(float2 val) +{ + return make_float2(fabs(val.x), fabs(val.y)); +} + +template <> +__device__ inline half cuda_abs(half val) +{ + return __habs(val); +} + +template <> +__device__ inline half2 cuda_abs(half2 val) +{ + return __habs2(val); +} + +#ifdef ENABLE_BF16 + +#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) +template <> +__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val) +{ + return __habs(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val) +{ + return __habs2(val); +} +#endif + +#endif // ENABLE_FP16 + +template +__device__ inline To cuda_sum(Ti val) +{ + return cuda_cast(val); +}; + +template +__device__ inline To cuda_sum(float2 val) +{ + return cuda_cast(val.x + val.y); +}; + +// Unary maximum: compute the max of a vector type +template +__device__ inline To cuda_max(Ti val) +{ + return cuda_cast(val); +}; + +template <> +__device__ inline float cuda_max(float2 val) +{ + return fmaxf(val.x, val.y); +} + +template <> +__device__ inline half cuda_max(half2 val) +{ + return __hmax(val.x, val.y); +} + +#ifdef ENABLE_BF16 +template <> +__device__ inline __nv_bfloat16 
cuda_max(__nv_bfloat162 val) +{ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __hmax(val.x, val.y); +#else + assert(0); + asm volatile("brkpt;\n" ::); + return __nv_bfloat16(0); +#endif +} +#endif + +// Binary maximum: compute the max of two values. +template +__device__ inline T cuda_max(T val1, T val2) +{ + return (val1 > val2) ? val1 : val2; +} + +template <> +__device__ inline float2 cuda_max(float2 val1, float2 val2) +{ + float2 out; + out.x = fmaxf(val1.x, val2.x); + out.y = fmaxf(val1.y, val2.y); + return out; +} + +template <> +__device__ inline half2 cuda_max(half2 val1, half2 val2) +{ + return __hmax2(val1, val2); +} + +#ifdef ENABLE_BF16 +template <> +__device__ inline __nv_bfloat162 cuda_max(__nv_bfloat162 val1, __nv_bfloat162 val2) +{ + return __hmax2(val1, val2); +} +#endif // ENABLE_BF16 + +// Binary maximum: compute the min of two values. +template +__device__ inline T cuda_min(T val1, T val2) +{ + return (val1 < val2) ? val1 : val2; +} + +template <> +__device__ inline float2 cuda_min(float2 val1, float2 val2) +{ + float2 out; + out.x = fminf(val1.x, val2.x); + out.y = fminf(val1.y, val2.y); + return out; +} + +template <> +__device__ inline half2 cuda_min(half2 val1, half2 val2) +{ + return __hmin2(val1, val2); +} + +#ifdef ENABLE_BF16 +template <> +__device__ inline __nv_bfloat162 cuda_min(__nv_bfloat162 val1, __nv_bfloat162 val2) +{ + return __hmin2(val1, val2); +} +#endif // ENABLE_BF16 + +// Helper function of clamping the val into the given range. +template +inline __device__ T cuda_clamp(T val, T minVal, T maxVal) +{ + return cuda_min(cuda_max(val, minVal), maxVal); +} + +#ifdef ENABLE_FP8 +template <> +__device__ inline float2 cuda_cast(__nv_fp8x2_e4m3 val) +{ + return bf1622float2(fp8x2_e4m3_to_bfloat2(&val)); +} + +template <> +__device__ inline half2 cuda_cast(__nv_fp8x2_e4m3 val) +{ + return fp8x2_e4m3_to_half2(&val); +} + +template <> +__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, float2>(float2 val) +{ + return __nv_fp8x2_e4m3(bf1622float2(float22bf162(val))); +} + +template <> +__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, half2>(half2 val) +{ + return __nv_fp8x2_e4m3(cuda_cast(val)); +} + +template <> +__device__ inline __nv_fp8x2_e4m3 cuda_cast<__nv_fp8x2_e4m3, __nv_bfloat162>(__nv_bfloat162 val) +{ + return __nv_fp8x2_e4m3(cuda_cast(val)); +} + +template <> +__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, half>(half val) +{ + return __nv_fp8_e4m3(val); +} + +template <> +__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, __nv_bfloat16>(__nv_bfloat16 val) +{ + return __nv_fp8_e4m3(val); +} + +template <> +__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, float>(float val) +{ + return __nv_fp8_e4m3(val); +} + +template <> +__device__ inline float cuda_cast(__nv_fp8_e4m3 val) +{ + return (float) val; +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_fp8x2_e4m3>(__nv_fp8x2_e4m3 val) +{ + return fp8x2_e4m3_to_bfloat2(&val); +} + +template <> +__device__ inline int8_t cuda_cast(__nv_fp8_e4m3 val) +{ + // no impl + return 0; +} + +template <> +__device__ inline __nv_fp8_e4m3 cuda_cast<__nv_fp8_e4m3, int8_t>(int8_t val) +{ + return cuda_cast<__nv_fp8_e4m3>(cuda_cast<__nv_bfloat16>(cuda_cast(val))); +} + +#endif // ENABLE_FP8 + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h new file mode 100644 index 
00000000000..d7bf43b4075 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/customAllReduceUtils.h @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace tensorrt_llm::utils::customAllReduceUtils +{ + +constexpr size_t NUM_POINTERS_PER_RANK = 7; + +// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py +inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept +{ + if (worldSize <= 2) + { + return 16 * 1000 * 1000; + } + return 8 * 1000 * 1000; +} + +} // namespace tensorrt_llm::utils::customAllReduceUtils diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp new file mode 100644 index 00000000000..64d3d44acb8 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.cpp @@ -0,0 +1,214 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "envUtils.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/logger.h" +#include + +namespace tensorrt_llm::common +{ + +std::optional getIntEnv(char const* name) +{ + char const* const env = std::getenv(name); + if (env == nullptr) + { + return std::nullopt; + } + int32_t const val = std::stoi(env); + if (val <= 0) + { + return std::nullopt; + } + return {val}; +}; + +// Returns true if the env variable exists and is set to "1" +static bool getBoolEnv(char const* name) +{ + char const* env = std::getenv(name); + return env && env[0] == '1' && env[1] == '\0'; +} + +// XQA kernels (optimized kernels for generation phase). +bool forceXQAKernels() +{ + static bool const forceXQA = (getIntEnv("TRTLLM_FORCE_XQA").value_or(0) != 0); + return forceXQA; +} + +std::optional getEnvEnableXQAJIT() +{ + static bool init = false; + static bool exists = false; + static bool enableXQAJIT = false; + if (!init) + { + init = true; + char const* enable_xqa_jit_var = std::getenv("TRTLLM_ENABLE_XQA_JIT"); + if (enable_xqa_jit_var) + { + exists = true; + if (enable_xqa_jit_var[0] == '1' && enable_xqa_jit_var[1] == '\0') + { + enableXQAJIT = true; + } + } + } + if (exists) + { + return enableXQAJIT; + } + else + { + return std::nullopt; + } +} + +// Tune the number of blocks per sequence for accuracy/performance purpose. 
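// Editor's sketch (not part of the upstream patch): the MMHA getters below, like the ones
// above, read their environment variable once and cache the parsed result in a
// function-local static. A hypothetical getter written directly on top of getIntEnv()
// would collapse to:
//
//   int getEnvFooBlocks()  // the getter name and "TRTLLM_FOO_BLOCKS" are illustrative only
//   {
//       static int const value = getIntEnv("TRTLLM_FOO_BLOCKS").value_or(0);
//       return value;
//   }
//
// The handwritten forms below keep the longer pattern so each getter can emit its own
// warning when the variable is set to an invalid value.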
+bool getEnvMmhaMultiblockDebug() +{ + static bool init = false; + static bool forceMmhaMaxSeqLenTile = false; + if (!init) + { + init = true; + char const* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG"); + if (enable_mmha_debug_var) + { + if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0') + { + forceMmhaMaxSeqLenTile = true; + } + } + } + return forceMmhaMaxSeqLenTile; +} + +int getEnvMmhaBlocksPerSequence() +{ + static bool init = false; + static int mmhaBlocksPerSequence = 0; + if (!init) + { + init = true; + char const* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE"); + if (mmhaBlocksPerSequenceEnv) + { + mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv); + if (mmhaBlocksPerSequence <= 0) + { + TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_BLOCKS_PER_SEQUENCE. Will use default values instead!"); + } + } + } + return mmhaBlocksPerSequence; +} + +int getEnvMmhaKernelBlockSize() +{ + static bool init = false; + static int mmhaKernelBlockSize = 0; + if (!init) + { + init = true; + char const* mmhaKernelBlockSizeEnv = std::getenv("TRTLLM_MMHA_KERNEL_BLOCK_SIZE"); + if (mmhaKernelBlockSizeEnv) + { + mmhaKernelBlockSize = std::atoi(mmhaKernelBlockSizeEnv); + if (mmhaKernelBlockSize <= 0) + { + TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_KERNEL_BLOCK_SIZE. Will use default values instead!"); + } + } + } + return mmhaKernelBlockSize; +} + +bool getEnvEnablePDL() +{ + static bool init = false; + static bool enablePDL = false; + if (!init) + { + init = true; + // PDL only available when arch >= 90 + if (getSMVersion() >= 90) + { + // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1` + enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL"); + } + } + return enablePDL; +} + +bool getEnvUseUCXKvCache() +{ + static bool const useUCXKVCache = getBoolEnv("TRTLLM_USE_UCX_KVCACHE"); + return useUCXKVCache; +} + +std::string getEnvUCXInterface() +{ + static bool init = false; + static std::string ucxInterface; + if (!init) + { + init = true; + { + char const* ucx_interface = std::getenv("TRTLLM_UCX_INTERFACE"); + if (ucx_interface) + { + ucxInterface = ucx_interface; + } + } + } + return ucxInterface; +} + +bool getEnvDisaggLayerwise() +{ + static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE"); + return disaggLayerwise; +} + +bool getEnvParallelCacheSend() +{ + static bool const parallelCacheSend = getBoolEnv("TRTLLM_PARALLEL_CACHE_SEND"); + return parallelCacheSend; +} + +bool getEnvRequestKVCacheSerial() +{ + static bool const requestKVCacheSerial = getBoolEnv("TRTLLM_REQUEST_KV_CACHE_SERIAL"); + return requestKVCacheSerial; +} + +bool getEnvDisableKVCacheTransferOverlap() +{ + static bool const disableKVCacheTransferOverlap = getBoolEnv("TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP"); + return disableKVCacheTransferOverlap; +} + +bool getEnvDisableReceiveKVCacheParallel() +{ + static bool const disableReceiveParallel = getBoolEnv("TRTLLM_DISABLE_KVCACHE_RECEIVE_PARALLEL"); + return disableReceiveParallel; +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h new file mode 100644 index 00000000000..027c7cfbb3b --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/envUtils.h @@ -0,0 +1,60 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +namespace tensorrt_llm::common +{ +// Useful when you want to inject some debug code controllable with env var. +std::optional getIntEnv(char const* name); + +// XQA kernels (optimized kernels for generation phase). +bool forceXQAKernels(); + +// Whether XQA JIT is enabled. +// +// Returns the value of TRTLLM_ENABLE_XQA_JIT env var. If such env var doesn't exist, std::nullopt is returned. +std::optional getEnvEnableXQAJIT(); + +// Tune the number of blocks per sequence for accuracy/performance purpose. +bool getEnvMmhaMultiblockDebug(); + +int getEnvMmhaBlocksPerSequence(); + +int getEnvMmhaKernelBlockSize(); + +// Whether PDL is enabled. +bool getEnvEnablePDL(); + +bool getEnvUseUCXKvCache(); + +std::string getEnvUCXInterface(); + +bool getEnvDisaggLayerwise(); + +bool getEnvParallelCacheSend(); + +bool getEnvRequestKVCacheSerial(); + +bool getEnvDisableKVCacheTransferOverlap(); + +bool getEnvDisableReceiveKVCacheParallel(); + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp new file mode 100644 index 00000000000..334ad236906 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/logger.cpp @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/common/tllmException.h" +#include + +namespace tensorrt_llm::common +{ + +Logger::Logger() +{ + char* isFirstRankOnlyChar = std::getenv("TLLM_LOG_FIRST_RANK_ONLY"); + bool isFirstRankOnly = (isFirstRankOnlyChar != nullptr && std::string(isFirstRankOnlyChar) == "ON"); + + auto const* levelName = std::getenv("TLLM_LOG_LEVEL"); + if (levelName != nullptr) + { + auto level = [levelName = std::string(levelName)]() + { + if (levelName == "TRACE") + return TRACE; + if (levelName == "DEBUG") + return DEBUG; + if (levelName == "INFO") + return INFO; + if (levelName == "WARNING") + return WARNING; + if (levelName == "ERROR") + return ERROR; + TLLM_THROW("Invalid log level: %s", levelName.c_str()); + }(); + // If TLLM_LOG_FIRST_RANK_ONLY=ON, set LOG LEVEL of other device to ERROR + if (isFirstRankOnly) + { + auto const deviceId = getDevice(); + if (deviceId != 1) + { + level = ERROR; + } + } + setLevel(level); + } +} + +void Logger::log(std::exception const& ex, Logger::Level level) +{ + log(level, "%s: %s", TllmException::demangle(typeid(ex).name()).c_str(), ex.what()); +} + +Logger* Logger::getLogger() +{ + thread_local Logger instance; + return &instance; +} +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h new file mode 100644 index 00000000000..1bad3a2c152 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/mathUtils.h @@ -0,0 +1,37 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace tensorrt_llm +{ +namespace common +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ __host__ T divUp(T m, T n) +{ + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu new file mode 100644 index 00000000000..d13217b203a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.cu @@ -0,0 +1,906 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaTypeUtils.cuh" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/common/memoryUtils.h" + +#include +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ + +template +void deviceMalloc(T** ptr, size_t size, bool is_random_initialize) +{ + check_cuda_error(cudaMalloc((void**) (ptr), sizeof(T) * size)); + if (is_random_initialize) + { + cudaRandomUniform(*ptr, size); + } +} + +template void deviceMalloc(float** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(half** ptr, size_t size, bool is_random_initialize); +#ifdef ENABLE_BF16 +template void deviceMalloc(__nv_bfloat16** ptr, size_t size, bool is_random_initialize); +#endif +template void deviceMalloc(uint16_t** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(int** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(bool** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(char** ptr, size_t size, bool is_random_initialize); +template void deviceMalloc(int8_t** ptr, size_t size, bool is_random_initialize); +#ifdef ENABLE_FP8 +template void deviceMalloc(__nv_fp8_e4m3** ptr, size_t size, bool is_random_initialize); +#endif + +template +void deviceMemSetZero(T* ptr, size_t size) +{ + check_cuda_error(cudaMemset(static_cast(ptr), 0, sizeof(T) * size)); +} + +template void deviceMemSetZero(float* ptr, size_t size); +template void deviceMemSetZero(half* ptr, size_t size); +template void deviceMemSetZero(int* ptr, size_t size); +template void deviceMemSetZero(uint32_t* ptr, size_t size); +template void deviceMemSetZero(bool* ptr, size_t size); +#ifdef ENABLE_FP8 +template void deviceMemSetZero(__nv_fp8_e4m3* ptr, size_t size); +#endif +#ifdef ENABLE_BF16 +template void deviceMemSetZero(__nv_bfloat16* ptr, size_t size); +#endif + +template +void deviceFree(T*& ptr) +{ + if (ptr != NULL) + { + check_cuda_error(cudaFree(ptr)); + ptr = NULL; + } +} + +template void deviceFree(float*& ptr); +template void deviceFree(half*& ptr); +#ifdef ENABLE_BF16 +template void deviceFree(__nv_bfloat16*& ptr); +#endif +template void deviceFree(unsigned short*& ptr); +template void deviceFree(int*& ptr); +template void deviceFree(bool*& ptr); +template void deviceFree(char*& ptr); +template void deviceFree(int8_t*& ptr); +#ifdef ENABLE_FP8 +template void deviceFree(__nv_fp8_e4m3*& ptr); +#endif + +template +void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream) +{ + T* arr = new T[size]; + std::fill(arr, arr + size, value); + check_cuda_error(cudaMemcpyAsync(devptr, arr, sizeof(T) * size, cudaMemcpyHostToDevice, stream)); + delete[] arr; +} + +template void deviceFill(float* devptr, size_t size, float value, cudaStream_t stream); +template void deviceFill(half* devptr, size_t size, half value, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void deviceFill(__nv_bfloat16* devptr, size_t size, __nv_bfloat16 value, cudaStream_t stream); +#endif +template void deviceFill(int* devptr, size_t size, int value, cudaStream_t stream); +template void deviceFill(bool* devptr, size_t size, bool value, cudaStream_t stream); + +template +void cudaD2Hcpy(T* tgt, T const* src, const size_t size) +{ + check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToHost)); +} + +template void cudaD2Hcpy(float* tgt, float const* src, size_t size); 
+template void cudaD2Hcpy(half* tgt, half const* src, size_t size); +#ifdef ENABLE_BF16 +template void cudaD2Hcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size); +#endif +template void cudaD2Hcpy(int* tgt, int const* src, size_t size); +template void cudaD2Hcpy(bool* tgt, bool const* src, size_t size); +#ifdef ENABLE_FP8 +template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size); +#endif +template void cudaD2Hcpy(unsigned long long* tgt, unsigned long long const* src, size_t size); +template void cudaD2Hcpy(unsigned int* tgt, unsigned int const* src, size_t size); +template void cudaD2Hcpy(int8_t* tgt, int8_t const* src, size_t size); + +template +void cudaH2Dcpy(T* tgt, T const* src, const size_t size) +{ + check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyHostToDevice)); +} + +template void cudaH2Dcpy(float* tgt, float const* src, size_t size); +template void cudaH2Dcpy(half* tgt, half const* src, size_t size); +#ifdef ENABLE_BF16 +template void cudaH2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size); +#endif +template void cudaH2Dcpy(int* tgt, int const* src, size_t size); +template void cudaH2Dcpy(bool* tgt, bool const* src, size_t size); +#ifdef ENABLE_FP8 +template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size); +#endif +template void cudaH2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size); +template void cudaH2Dcpy(unsigned int* tgt, unsigned int const* src, size_t size); +template void cudaH2Dcpy(int8_t* tgt, int8_t const* src, size_t size); + +template +void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream) +{ + check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToDevice, stream)); +} + +template void cudaD2Dcpy(float* tgt, float const* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(half* tgt, half const* src, size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void cudaD2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream); +#endif +template void cudaD2Dcpy(int* tgt, int const* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream); +template void cudaD2Dcpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream); +#ifdef ENABLE_FP8 +template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size, cudaStream_t stream); +#endif +template void cudaD2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream); + +template +__global__ void cudaCast(T_OUT* dst, T_IN* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = (T_OUT) ((float) (src[tid])); + } +} + +template +void invokeCudaCast(T_OUT* dst, T_IN const* const src, const size_t size, cudaStream_t stream) +{ + cudaCast<<<256, 256, 0, stream>>>(dst, src, size); +} + +template void invokeCudaCast(float* dst, half const* const src, const size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeCudaCast(float* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(__nv_bfloat16* dst, float const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(__nv_bfloat16* dst, half const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(half* dst, 
__nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); +#endif +#ifdef ENABLE_FP8 +template void invokeCudaCast(float* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast( + __nv_bfloat16* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(half* dst, __nv_fp8_e4m3 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(__nv_fp8_e4m3* dst, float const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast( + __nv_fp8_e4m3* dst, __nv_bfloat16 const* const src, const size_t size, cudaStream_t stream); +template void invokeCudaCast(__nv_fp8_e4m3* dst, half const* const src, const size_t size, cudaStream_t stream); +#endif + +template +void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream) +{ + if (stream != NULL) + { + check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDefault, stream)); + } + else + { + check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDefault)); + } +} + +template void cudaAutoCpy(float* tgt, float const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(half* tgt, half const* src, size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void cudaAutoCpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream); +#endif +template void cudaAutoCpy(int* tgt, int const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(uint8_t* tgt, uint8_t const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(uint32_t* tgt, uint32_t const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(unsigned long* tgt, unsigned long const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(char* tgt, char const* src, size_t size, cudaStream_t stream); + +template void cudaAutoCpy(float const** tgt, float const* const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(half const** tgt, half const* const* src, size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void cudaAutoCpy(__nv_bfloat16 const** tgt, __nv_bfloat16 const* const* src, size_t size, cudaStream_t stream); +#endif +template void cudaAutoCpy(int const** tgt, int const* const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(bool const** tgt, bool const* const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy(int8_t const** tgt, int8_t const* const* src, size_t size, cudaStream_t stream); +template void cudaAutoCpy( + unsigned long long const** tgt, unsigned long long const* const* src, size_t size, cudaStream_t stream); + +template +__global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, int const seq_offset) +{ + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + curandState_t local_state; + curand_init((unsigned long long int) 1337, idx + seq_offset, 0, &local_state); + for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) + { + buffer[index] = (T) (curand_uniform(&local_state) * 0.2f - 0.1f); + } +} + +template <> +__global__ void cuda_random_uniform_kernel(int* buffer, 
const size_t size, int const seq_offset) +{ + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + curandState_t local_state; + curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); + for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) + { + buffer[index] = curand(&local_state); + } +} + +template <> +__global__ void cuda_random_uniform_kernel(bool* buffer, const size_t size, int const seq_offset) +{ + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + curandState_t local_state; + curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); + for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) + { + buffer[index] = (curand(&local_state) % 2 == 0); + } +} + +template <> +__global__ void cuda_random_uniform_kernel(char* buffer, const size_t size, int const seq_offset) +{ + const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + curandState_t local_state; + curand_init((float) 1337.f, idx + seq_offset, 0, &local_state); + for (size_t index = idx; index < size; index += blockDim.x * gridDim.x) + { + buffer[index] = curand(&local_state) % 0xFF; + } +} + +template +void cudaRandomUniform(T* buffer, const size_t size) +{ + static int seq_offset = 0; + cuda_random_uniform_kernel<<<256, 256>>>(buffer, size, seq_offset); + seq_offset += 256 * 256; +} + +template void cudaRandomUniform(float* buffer, const size_t size); +template void cudaRandomUniform(half* buffer, const size_t size); +#ifdef ENABLE_BF16 +template void cudaRandomUniform(__nv_bfloat16* buffer, const size_t size); +#endif +template void cudaRandomUniform(int* buffer, const size_t size); +template void cudaRandomUniform(bool* buffer, const size_t size); +template void cudaRandomUniform(char* buffer, const size_t size); +#ifdef ENABLE_FP8 +template void cudaRandomUniform(__nv_fp8_e4m3* buffer, const size_t size); +#endif + +// loads data from binary file. If it succeeds, returns a non-empty vector. If loading fails or +// the product of the elements in shape is 0, this function will return an empty vector. +template +std::vector loadWeightFromBinHelper(std::vector shape, std::string filename) +{ + if (shape.size() > 2) + { + printf("[ERROR] shape should have less than two dims \n"); + return std::vector(); + } + size_t dim0 = shape[0], dim1 = 1; + if (shape.size() == 2) + { + dim1 = shape[1]; + } + size_t size = dim0 * dim1; + if (size == 0) + { + TLLM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str()); + return std::vector(); + } + + std::vector host_array(size); + std::ifstream in(filename, std::ios::in | std::ios::binary); + if (!in.is_open()) + { + TLLM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str()); + return std::vector(); + } + + size_t loaded_data_size = sizeof(T) * size; + in.seekg(0, in.end); + in.seekg(0, in.beg); + + TLLM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename); + in.read((char*) host_array.data(), loaded_data_size); + + size_t in_get_size = in.gcount(); + if (in_get_size != loaded_data_size) + { + TLLM_LOG_WARNING("file %s only has %ld, but request %ld, loading model fails! \n", filename.c_str(), + in_get_size, loaded_data_size); + return std::vector(); + } + in.close(); + // If we succeed, return an array with values. 
+ return host_array; +} + +template +int loadWeightFromBinFunc(T* ptr, std::vector shape, std::string filename) +{ + std::vector host_array = loadWeightFromBinHelper(shape, filename); + + if (host_array.empty()) + { + return 0; + } + + if (std::is_same::value == true) + { + cudaH2Dcpy(ptr, (T*) host_array.data(), host_array.size()); + } + else + { + T_IN* ptr_2 = nullptr; + deviceMalloc(&ptr_2, host_array.size(), false); + cudaH2Dcpy(ptr_2, host_array.data(), host_array.size()); + invokeCudaD2DcpyConvert(ptr, ptr_2, host_array.size()); + deviceFree(ptr_2); + } + return 0; +} + +template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(int8_t* ptr, std::vector shape, std::string filename); +#ifdef ENABLE_BF16 +template int loadWeightFromBinFunc<__nv_bfloat16, float>( + __nv_bfloat16* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc<__nv_bfloat16, half>( + __nv_bfloat16* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(float* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc(half* ptr, std::vector shape, std::string filename); +template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>( + __nv_bfloat16* ptr, std::vector shape, std::string filename); +#endif // ENABLE_BF16 +template int loadWeightFromBinFunc(int* ptr, std::vector shape, std::string filename); +#ifdef ENABLE_FP8 +template int loadWeightFromBinFunc<__nv_fp8_e4m3, float>( + __nv_fp8_e4m3* ptr, std::vector shape, std::string filename); +#endif // ENABLE_FP8 + +template +int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type) +{ + switch (model_file_type) + { + case TRTLLMCudaDataType::FP32: loadWeightFromBinFunc(ptr, shape, filename); break; + case TRTLLMCudaDataType::FP16: loadWeightFromBinFunc(ptr, shape, filename); break; + case TRTLLMCudaDataType::INT8: loadWeightFromBinFunc(ptr, shape, filename); break; +#ifdef ENABLE_BF16 + case TRTLLMCudaDataType::BF16: loadWeightFromBinFunc(ptr, shape, filename); break; +#endif +#ifdef ENABLE_FP8 + case TRTLLMCudaDataType::FP8: loadWeightFromBinFunc(ptr, shape, filename); break; +#endif + default: TLLM_LOG_ERROR("Does not support TRTLLMCudaDataType=%d", model_file_type); TLLM_CHECK(false); + } + return 0; +} + +template <> +int loadWeightFromBin(int* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type) +{ + loadWeightFromBinFunc(ptr, shape, filename); + return 0; +} + +template int loadWeightFromBin( + float* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +template int loadWeightFromBin( + half* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +template int loadWeightFromBin( + int8_t* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +#ifdef ENABLE_BF16 +template int loadWeightFromBin( + __nv_bfloat16* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +#endif +#ifdef ENABLE_FP8 +template int loadWeightFromBin( + __nv_fp8_e4m3* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); +#endif +template int 
loadWeightFromBin( + int* ptr, std::vector shape, std::string filename, TRTLLMCudaDataType model_file_type); + +template +__global__ void cudaD2DcpyConvert(T_OUT* dst, const T_IN* src, const size_t size) +{ + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = cuda_cast(src[tid]); + } +} + +template +void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cudaStream_t stream) +{ + cudaD2DcpyConvert<<<256, 256, 0, stream>>>(tgt, src, size); +} + +template void invokeCudaD2DcpyConvert(int8_t* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, int8_t const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(half* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(half* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, half const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(uint32_t* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, uint32_t const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, half const* src, const size_t size, cudaStream_t stream); + +#ifdef ENABLE_BF16 +template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, float const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, int const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(float* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream); +template void invokeCudaD2DcpyConvert(int* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream); +#endif // ENABLE_BF16 + +template +__global__ void cudaD2DScaleCpyConvert( + T_OUT* dst, const T_IN* src, float const* scale, bool invert_scale, const size_t size) +{ + float const scale_value = invert_scale ? 
1.0f / scale[0] : scale[0]; + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) + { + dst[tid] = cuda_cast(cuda_cast(src[tid]) * scale_value); + } +} + +template +void invokeCudaD2DScaleCpyConvert( + T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream) +{ + cudaD2DScaleCpyConvert<<<256, 256, 0, stream>>>(tgt, src, scale, invert_scale, size); +} + +// clang-format off +template void invokeCudaD2DScaleCpyConvert(float* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const float* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +template void invokeCudaD2DScaleCpyConvert(half* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const half* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeCudaD2DScaleCpyConvert(__nv_bfloat16* tgt, const int32_t* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +template void invokeCudaD2DScaleCpyConvert(int32_t* tgt, const __nv_bfloat16* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +#endif // ENABLE_BF16 +#ifdef ENABLE_FP8 +template void invokeCudaD2DScaleCpyConvert(float* tgt, const __nv_fp8_e4m3* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream); +#endif // ENABLE_FP8 +// clang-format on + +void invokeCudaD2DcpyHalf2Float(float* dst, half* src, const size_t size, cudaStream_t stream) +{ + invokeCudaD2DcpyConvert(dst, src, size, stream); +} + +void invokeCudaD2DcpyFloat2Half(half* dst, float* src, const size_t size, cudaStream_t stream) +{ + invokeCudaD2DcpyConvert(dst, src, size, stream); +} + +template +void saveToBinary(T const* ptr, const size_t size, std::string filename) +{ + + std::vector h_ptr(size); + cudaD2Hcpy(h_ptr.data(), ptr, size); + std::vector float_ptr(size); + for (size_t i = 0; i < size; i++) + { + float_ptr[i] = (float) h_ptr[i]; + } + + std::ofstream out(filename, std::ios::out | std::ios::binary); + TLLM_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename); + + out.write((char*) float_ptr.data(), size * sizeof(float)); +} + +template void saveToBinary(float const* ptr, const size_t size, std::string filename); +template void saveToBinary(half const* ptr, const size_t size, std::string filename); +#ifdef ENABLE_BF16 +template void saveToBinary(__nv_bfloat16 const* ptr, const size_t size, std::string filename); +#endif // ENABLE_BF16 + +template <> +void saveToBinary(int const* ptr, const size_t size, std::string filename) +{ + std::vector h_ptr(size); + cudaD2Hcpy(h_ptr.data(), ptr, size); + std::ofstream out(filename, std::ios::out | std::ios::binary); + TLLM_CHECK_WITH_INFO(out.is_open(), "Fail to open file " + filename); + out.write((char*) h_ptr.data(), size * sizeof(int)); +} + +template +__global__ void fakeCast(T_IN* input_ptr, const size_t size) +{ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) + { + T_fake_type tmp_val = (T_fake_type) ((float) input_ptr[i]); + input_ptr[i] = (T_IN) ((float) tmp_val); + } +} + +template +void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream) +{ + dim3 block(256); + dim3 
grid((size + 255) / 256);
+    fakeCast<T_IN, T_fake_type><<<grid, block, 0, stream>>>(input_ptr, size);
+}
+
+#ifdef ENABLE_FP8
+__global__ void cudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, const size_t size)
+{
+    for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
+    {
+        dst[tid] = (float) (src[tid]);
+    }
+}
+
+void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, const size_t size, cudaStream_t stream)
+{
+    cudaD2Dcpyfp82Float<<<256, 256, 0, stream>>>(dst, src, size);
+}
+
+__global__ void cudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, const size_t size)
+{
+    for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
+    {
+        dst[tid] = (half) ((float) (src[tid]));
+    }
+}
+
+void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, const size_t size, cudaStream_t stream)
+{
+    cudaD2Dcpyfp82Half<<<256, 256, 0, stream>>>(dst, src, size);
+}
+
+__global__ void cudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, const size_t size)
+{
+    for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
+    {
+        dst[tid] = (__nv_fp8_e4m3) src[tid];
+    }
+}
+
+void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, const size_t size, cudaStream_t stream)
+{
+    cudaD2DcpyFloat2fp8<<<256, 256, 0, stream>>>(dst, src, size);
+}
+
+__global__ void cudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, const size_t size)
+{
+    for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
+    {
+        dst[tid] = (__nv_fp8_e4m3) src[tid];
+    }
+}
+
+void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, const size_t size, cudaStream_t stream)
+{
+    cudaD2DcpyHalf2fp8<<<256, 256, 0, stream>>>(dst, src, size);
+}
+
+__global__ void cudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, const size_t size)
+{
+    for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
+    {
+        dst[tid] = (__nv_fp8_e4m3) src[tid];
+    }
+}
+
+void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, const size_t size, cudaStream_t stream)
+{
+    cudaD2DcpyBfloat2fp8<<<256, 256, 0, stream>>>(dst, src, size);
+}
+
+#endif // ENABLE_FP8
+
+template <typename T_OUT, typename T_IN>
+__global__ void transpose(T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1)
+{
+    for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1; tid += blockDim.x * gridDim.x)
+    {
+        const size_t src_col_id = tid % dim1;
+        const size_t src_row_id = tid / dim1;
+        dst[src_col_id * dim0 + src_row_id] = (T_OUT) (src[tid]);
+    }
+}
+
+template <typename T>
+void invokeInPlaceTranspose(T* data, T* workspace, const size_t dim0, const size_t dim1)
+{
+    // copy data to workspace, and then transpose from workspace to data
+    cudaD2Dcpy(workspace, data, dim0 * dim1);
+    transpose<<<256, 256>>>(data, workspace, dim0, dim1);
+}
+
+#ifdef ENABLE_FP8
+template void invokeInPlaceTranspose(
+    __nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, const size_t dim1);
+#endif // ENABLE_FP8
+#ifdef ENABLE_BF16
+template void invokeInPlaceTranspose(
+    __nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, const size_t dim1);
+#endif // ENABLE_BF16
+template void invokeInPlaceTranspose(float* data, float* workspace, const size_t dim0, const size_t dim1);
+
+template <typename T_OUT, typename T_IN>
+__global__ void transpose0213(
+    T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3)
+{
+    // src permutation: [0, 1, 2, 3]
+    // dst permutation: [0, 2, 1, 3]
+    for (size_t tid =
threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2 * dim3; + tid += blockDim.x * gridDim.x) + { + size_t tmp_idx = tid; + const size_t dim_3_idx = tmp_idx % dim3; + tmp_idx = (tmp_idx - dim_3_idx) / dim3; + const size_t dim_2_idx = tmp_idx % dim2; + tmp_idx = (tmp_idx - dim_2_idx) / dim2; + const size_t dim_1_idx = tmp_idx % dim1; + tmp_idx = (tmp_idx - dim_1_idx) / dim1; + const size_t dim_0_idx = tmp_idx % dim0; + dst[dim_0_idx * dim1 * dim2 * dim3 + dim_2_idx * dim1 * dim3 + dim_1_idx * dim3 + dim_3_idx] = src[tid]; + } +} + +template +void invokeInPlaceTranspose0213( + T* data, T* workspace, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3) +{ + // copy data to workspace, and then transpose from workspace to data + // Note that this kernel is used for pre-processing and not very efficient. + cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2 * dim3); + transpose0213<<<256, 256>>>(data, workspace, dim0, dim1, dim2, dim3); +} + +#ifdef ENABLE_FP8 +template void invokeInPlaceTranspose0213(__nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, + const size_t dim1, const size_t dim2, const size_t dim3); +#endif // ENABLE_FP8 +#ifdef ENABLE_BF16 +template void invokeInPlaceTranspose0213(__nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, + const size_t dim1, const size_t dim2, const size_t dim3); +#endif // ENABLE_BF16 +template void invokeInPlaceTranspose0213( + float* data, float* workspace, const size_t dim0, const size_t dim1, const size_t dim2, const size_t dim3); + +template +__global__ void transpose102(T_OUT* dst, T_IN* src, const size_t dim0, const size_t dim1, const size_t dim2) +{ + // src permutation: [0, 1, 2] + // dst permutation: [1, 0, 2] + for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < dim0 * dim1 * dim2; tid += blockDim.x * gridDim.x) + { + size_t tmp_idx = tid; + const size_t dim_2_idx = tmp_idx % dim2; + tmp_idx = (tmp_idx - dim_2_idx) / dim2; + const size_t dim_1_idx = tmp_idx % dim1; + tmp_idx = (tmp_idx - dim_1_idx) / dim1; + const size_t dim_0_idx = tmp_idx % dim0; + dst[dim_1_idx * dim0 * dim2 + dim_0_idx * dim2 + dim_2_idx] = src[tid]; + } +} + +template +void invokeInPlaceTranspose102(T* data, T* workspace, const size_t dim0, const size_t dim1, const size_t dim2) +{ + // copy data to workspace, and then transpose from workspace to data + // Note that this kernel is used for pre-processing and not very efficient. 
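+    // For example, viewing the input as [dim0, dim1, dim2], source element (i, j, k)
+    // ends up at output position (j, i, k): the first two dimensions are swapped and
+    // the innermost dimension stays contiguous.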
+ cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2); + transpose102<<<256, 256>>>(data, workspace, dim0, dim1, dim2); +} + +#ifdef ENABLE_FP8 +template void invokeInPlaceTranspose102( + __nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const size_t dim0, const size_t dim1, const size_t dim2); +#endif // ENABLE_FP8 +#ifdef ENABLE_BF16 +template void invokeInPlaceTranspose102( + __nv_bfloat16* data, __nv_bfloat16* workspace, const size_t dim0, const size_t dim1, const size_t dim2); +#endif // ENABLE_BF16 +template void invokeInPlaceTranspose102( + float* data, float* workspace, const size_t dim0, const size_t dim1, const size_t dim2); + +template +void __global__ multiplyScale(T* tensor, float scale, const size_t size) +{ + for (size_t index = threadIdx.x + blockIdx.x * blockDim.x; index < size; index += blockDim.x * gridDim.x) + { + tensor[index] = (T) (((float) tensor[index]) * scale); + } +} + +template +void invokeMultiplyScale(T* tensor, float scale, const size_t size, cudaStream_t stream) +{ + int block = 256; + int grid = (size + 255) / 256; + multiplyScale<<>>(tensor, scale, size); +} + +template void invokeMultiplyScale(float* tensor, float scale, const size_t size, cudaStream_t stream); +template void invokeMultiplyScale(half* tensor, float scale, const size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeMultiplyScale(__nv_bfloat16* tensor, float scale, const size_t size, cudaStream_t stream); +#endif +#ifdef ENABLE_FP8 +template void invokeMultiplyScale(__nv_fp8_e4m3* tensor, float scale, const size_t size, cudaStream_t stream); +#endif + +template +void __global__ divideScale(T* tensor, float scale, const size_t size) +{ + for (size_t index = threadIdx.x + blockIdx.x * blockDim.x; index < size; index += blockDim.x * gridDim.x) + { + tensor[index] = (T) (((float) tensor[index]) / scale); + } +} + +template +void invokeDivideScale(T* tensor, float scale, const size_t size, cudaStream_t stream) +{ + int block = 256; + int grid = (size + 255) / 256; + divideScale<<>>(tensor, scale, size); +} + +template void invokeDivideScale(float* tensor, float scale, const size_t size, cudaStream_t stream); +template void invokeDivideScale(half* tensor, float scale, const size_t size, cudaStream_t stream); +#ifdef ENABLE_BF16 +template void invokeDivideScale(__nv_bfloat16* tensor, float scale, const size_t size, cudaStream_t stream); +#endif +#ifdef ENABLE_FP8 +template void invokeDivideScale(__nv_fp8_e4m3* tensor, float scale, const size_t size, cudaStream_t stream); +#endif +#ifdef ENABLE_BF16 +template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast<__nv_bfloat16, __nv_bfloat16>( + __nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast(half* input_ptr, const size_t size, cudaStream_t stream); +#endif +template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); +#ifdef ENABLE_FP8 +template void invokeFakeCast(float* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast(half* input_ptr, const size_t size, cudaStream_t stream); +template void invokeFakeCast<__nv_bfloat16, __nv_fp8_e4m3>( + __nv_bfloat16* input_ptr, const size_t size, cudaStream_t stream); +#endif + +size_t cuda_datatype_size(TRTLLMCudaDataType dt) +{ + static const std::unordered_map sizes{ + {TRTLLMCudaDataType::FP32, sizeof(float)}, {TRTLLMCudaDataType::FP16, 
sizeof(half)} +#ifdef ENABLE_BF16 + , + {TRTLLMCudaDataType::BF16, sizeof(__nv_bfloat16)} +#endif + }; + + return sizes.at(dt); +} + +template +__global__ void check_range(T const* buffer, size_t size, T min, T max, bool* d_within_range) +{ + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) + { + const T val = buffer[i]; + if (val < min || val > max) + { + *d_within_range = false; + } + } +} + +template +bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream) +{ + cudaMemsetAsync(d_within_range, true, sizeof(bool), stream); + + dim3 block(256); + dim3 grid((size + 255) / 256); + check_range<<>>(buffer, size, min, max, d_within_range); + + bool result; + cudaD2Hcpy(&result, d_within_range, 1); + return result; +} + +template bool invokeCheckRange( + int const* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream); + +/* + * Determine the total workspace size based on a vector containing multiple variable sizes. + */ +size_t calcAlignedSize(std::vector const& sizes, const size_t ALIGN_BYTES) +{ + const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); + // Check ALIGN_BYTES is a power of 2 + assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0); + + size_t total = 0; + for (auto sz : sizes) + { + total += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; + } + + // We add extra "ALIGN_BYTES - 1" bytes in case the start address passed to the function calcAlignedPointers() is + // not aligned. + return total + ALIGN_BYTES - 1; +} + +/* + * Given the address of the workspace and the vector containing multiple variable sizes, calculate the start addresses + * of each variable. + */ +void calcAlignedPointers( + std::vector& outPtrs, void const* p, std::vector const& sizes, size_t ALIGN_BYTES) +{ + const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); + // Check ALIGN_BYTES is a power of 2 + assert((ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0); + + // In case the start address is not aligned + char* ptr = reinterpret_cast((reinterpret_cast(p) + ALIGN_BYTES - 1) & ALIGN_MASK); + + outPtrs.reserve(sizes.size()); + for (auto sz : sizes) + { + outPtrs.push_back(ptr); + ptr += (sz + ALIGN_BYTES - 1) & ALIGN_MASK; + } +} + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h new file mode 100644 index 00000000000..9e413a1beb8 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/memoryUtils.h @@ -0,0 +1,292 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include "tensorrt_llm/common/cudaUtils.h" + +#include + +namespace tensorrt_llm +{ +namespace common +{ + +template +void deviceMalloc(T** ptr, size_t size, bool is_random_initialize = true); + +template +void deviceMemSetZero(T* ptr, size_t size); + +template + +void deviceFree(T*& ptr); + +template +void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0); + +template +void cudaD2Hcpy(T* tgt, T const* src, size_t const size); + +template +void cudaH2Dcpy(T* tgt, T const* src, size_t const size); + +template +void cudaD2Dcpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL); + +template +void cudaAutoCpy(T* tgt, T const* src, size_t const size, cudaStream_t stream = NULL); + +template +void cudaRandomUniform(T* buffer, size_t const size); + +template +int loadWeightFromBin(T* ptr, std::vector shape, std::string filename, + TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32); + +// template +// int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr, +// T* scale_ptr, +// std::vector shape, +// std::string filename, +// TRTLLMCudaDataType model_file_type = TRTLLMCudaDataType::FP32); + +void invokeCudaD2DcpyHalf2Float(float* dst, half* src, size_t const size, cudaStream_t stream); +void invokeCudaD2DcpyFloat2Half(half* dst, float* src, size_t const size, cudaStream_t stream); +#ifdef ENABLE_FP8 +void invokeCudaD2Dcpyfp82Float(float* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream); +void invokeCudaD2Dcpyfp82Half(half* dst, __nv_fp8_e4m3* src, size_t const size, cudaStream_t stream); +void invokeCudaD2DcpyFloat2fp8(__nv_fp8_e4m3* dst, float* src, size_t const size, cudaStream_t stream); +void invokeCudaD2DcpyHalf2fp8(__nv_fp8_e4m3* dst, half* src, size_t const size, cudaStream_t stream); +void invokeCudaD2DcpyBfloat2fp8(__nv_fp8_e4m3* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream); +#endif // ENABLE_FP8 +#ifdef ENABLE_BF16 +void invokeCudaD2DcpyBfloat2Float(float* dst, __nv_bfloat16* src, size_t const size, cudaStream_t stream); +#endif // ENABLE_BF16 + +template +void invokeCudaCast(T_OUT* dst, T_IN const* const src, size_t const size, cudaStream_t stream); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The following functions implement conversion of multi-dimensional indices to an index in a flat array. +// The shape of the Tensor dimensions is passed as one array (`dims`), the indices are given as individual arguments. +// For examples on how to use these functions, see their tests `test_memory_utils.cu`. +// All of these functions can be evaluated at compile time by recursive template expansion. + +template +__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( + T const& acc, TDim dims, TIndex const& index) +{ + assert(index < dims[0]); + return acc * dims[0] + index; +} + +template +__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( + T const& acc, TDim dims, TIndex const& index, TIndices... 
indices) +{ + assert(index < dims[0]); + return flat_index(acc * dims[0] + index, dims + 1, indices...); +} + +template +__inline__ __host__ __device__ std::enable_if_t::value, T> constexpr flat_index( + [[maybe_unused]] TDim dims, T const& index) +{ + assert(index < dims[0]); + return index; +} + +template +__inline__ __host__ __device__ + std::enable_if_t::value, typename std::remove_pointer::type> constexpr flat_index( + TDim dims, TIndex const& index, TIndices... indices) +{ + assert(index < dims[0]); + return flat_index(static_cast::type>(index), dims + 1, indices...); +} + +template +__inline__ __host__ __device__ T constexpr flat_index( + std::array const& dims, TIndex const& index, TIndices... indices) +{ + static_assert(skip < N); + static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); + return flat_index(&dims[skip], index, indices...); +} + +template +__inline__ __host__ __device__ T constexpr flat_index( + T const& acc, std::array const& dims, TIndex const& index, TIndices... indices) +{ + static_assert(skip < N); + static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); + return flat_index(acc, &dims[skip], index, indices...); +} + +template +__inline__ __host__ __device__ T constexpr flat_index(T const (&dims)[N], TIndex const& index, TIndices... indices) +{ + static_assert(skip < N); + static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); + return flat_index(static_cast(dims) + skip, index, indices...); +} + +template +__inline__ __host__ __device__ T constexpr flat_index( + T const& acc, T const (&dims)[N], TIndex const& index, TIndices... indices) +{ + static_assert(skip < N); + static_assert(sizeof...(TIndices) < N - skip, "Number of indices exceeds number of dimensions"); + return flat_index(acc, static_cast(dims) + skip, index, indices...); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// These are simpler functions for multi-dimensional index conversion. Indices and dimensions are passed as individual +// arguments. These functions are more suitable for usage inside kernels than the corresponding flat_index functions +// which require arrays as arguments. Usage examples can be found in `test_memory_utils.cu`. The functions can be +// evaluated at compile time. 
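+
+// For example, for a row-major tensor of shape [dim_0, dim_1, dim_2],
+// flat_index3(i, j, k, dim_1, dim_2) evaluates to (i * dim_1 + j) * dim_2 + k,
+// i.e. the offset of element (i, j, k) in the flattened buffer.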
+ +template +__inline__ __host__ __device__ T constexpr flat_index2(TIndex const& index_0, TIndex const& index_1, T const& dim_1) +{ + assert(index_1 < dim_1); + return index_0 * dim_1 + index_1; +} + +template +__inline__ __host__ __device__ T constexpr flat_index3( + TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& dim_1, T const& dim_2) +{ + assert(index_2 < dim_2); + return flat_index2(index_0, index_1, dim_1) * dim_2 + index_2; +} + +template +__inline__ __host__ __device__ T constexpr flat_index4(TIndex const& index_0, TIndex const& index_1, + TIndex const& index_2, TIndex const& index_3, T const& dim_1, T const& dim_2, T const& dim_3) +{ + assert(index_3 < dim_3); + return flat_index3(index_0, index_1, index_2, dim_1, dim_2) * dim_3 + index_3; +} + +template +__inline__ __host__ __device__ T constexpr flat_index5(TIndex const& index_0, TIndex const& index_1, + TIndex const& index_2, TIndex const& index_3, TIndex const& index_4, T const& dim_1, T const& dim_2, T const& dim_3, + T const& dim_4) +{ + assert(index_4 < dim_4); + return flat_index4(index_0, index_1, index_2, index_3, dim_1, dim_2, dim_3) * dim_4 + index_4; +} + +template +__inline__ __host__ __device__ T constexpr flat_index_strided3( + TIndex const& index_0, TIndex const& index_1, TIndex const& index_2, T const& stride_1, T const& stride_2) +{ + assert(index_1 < stride_1 / stride_2); + assert(index_2 < stride_2); + return index_0 * stride_1 + index_1 * stride_2 + index_2; +} + +template +__inline__ __host__ __device__ T constexpr flat_index_strided4(TIndex const& index_0, TIndex const& index_1, + TIndex const& index_2, TIndex const& index_3, T const& stride_1, T const& stride_2, T const& stride_3) +{ + assert(index_1 < stride_1 / stride_2); + assert(index_2 < stride_2 / stride_3); + assert(index_3 < stride_3); + return index_0 * stride_1 + index_1 * stride_2 + index_2 * stride_3 + index_3; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void invokeInPlaceTranspose(T* data, T* workspace, size_t const dim0, size_t const dim1); + +template +void invokeInPlaceTranspose0213( + T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2, size_t const dim3); + +template +void invokeInPlaceTranspose102(T* data, T* workspace, size_t const dim0, size_t const dim1, size_t const dim2); + +template +void invokeMultiplyScale(T* tensor, float scale, size_t const size, cudaStream_t stream); + +template +void invokeDivideScale(T* tensor, float scale, size_t const size, cudaStream_t stream); + +template +void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, size_t const size, cudaStream_t stream = 0); + +template +void invokeCudaD2DScaleCpyConvert( + T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, size_t const size, cudaStream_t stream = 0); + +inline bool checkIfFileExist(std::string const& file_path) +{ + std::ifstream in(file_path, std::ios::in | std::ios::binary); + if (in.is_open()) + { + in.close(); + return true; + } + return false; +} + +template +void saveToBinary(T const* ptr, size_t const size, std::string filename); + +template +void invokeFakeCast(T_IN* input_ptr, size_t const size, cudaStream_t stream); + +size_t cuda_datatype_size(TRTLLMCudaDataType dt); + +template +bool invokeCheckRange(T const* buffer, size_t const size, T min, T max, bool* d_within_range, cudaStream_t stream); + +constexpr size_t DEFAULT_ALIGN_BYTES = 256; + +size_t calcAlignedSize(std::vector const& sizes, 
size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES); +void calcAlignedPointers(std::vector& outPtrs, void const* p, std::vector const& sizes, + size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES); + +struct AlignedPointersUnpacker +{ + template + void operator()(T*&... outPtrs) + { + assert(sizeof...(T) == alignedPointers.size()); + auto it = alignedPointers.begin(); + ((outPtrs = static_cast(*it++)), ...); + } + + std::vector alignedPointers; +}; + +AlignedPointersUnpacker inline calcAlignedPointers( + void const* p, std::vector const& sizes, size_t ALIGN_BYTES = DEFAULT_ALIGN_BYTES) +{ + AlignedPointersUnpacker unpacker{}; + calcAlignedPointers(unpacker.alignedPointers, p, sizes, ALIGN_BYTES); + return unpacker; +} + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp new file mode 100644 index 00000000000..dbdaca4ee77 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/mpiUtils.cpp @@ -0,0 +1,588 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "tensorrt_llm/common/mpiUtils.h" + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/logger.h" +#include "tensorrt_llm/runtime/common.h" +#include "tensorrt_llm/runtime/iBuffer.h" + +#include +#include +#include +#include +#include +#ifndef _WIN32 +#include +#endif + +// We rely on SizeType32 being int32_t in some places with weak type checking, +// i.e. we're passing void ptr to some function. To prevent mysterious errors +// in the future, we trigger a compilation error here if SizeType32 isn't int32_t. 
+static_assert(std::is_same::value); + +namespace tensorrt_llm::mpi +{ + +MPI_Datatype getMpiDtype(MpiType dtype) +{ +#if ENABLE_MULTI_DEVICE + static std::unordered_map const dtype_map{ + {MpiType::kBYTE, MPI_BYTE}, + {MpiType::kHALF, MPI_UINT16_T}, + {MpiType::kFLOAT, MPI_FLOAT}, + {MpiType::kDOUBLE, MPI_DOUBLE}, + {MpiType::kBOOL, MPI_C_BOOL}, + {MpiType::kINT8, MPI_INT8_T}, + {MpiType::kUINT8, MPI_UINT8_T}, + {MpiType::kINT32, MPI_INT32_T}, + {MpiType::kUINT32, MPI_UINT32_T}, + {MpiType::kINT64, MPI_INT64_T}, + {MpiType::kUINT64, MPI_UINT64_T}, + {MpiType::kFP8, MPI_UINT8_T}, + {MpiType::kBF16, MPI_UINT16_T}, + {MpiType::kCHAR, MPI_CHAR}, + }; + return dtype_map.at(dtype); +#else + TLLM_THROW("Multi device support is disabled."); +#endif +} + +MPI_Op getMpiOp(MpiOp op) +{ +#if ENABLE_MULTI_DEVICE + static std::unordered_map const op_map{ + {MpiOp::NULLOP, MPI_OP_NULL}, + {MpiOp::MAX, MPI_MAX}, + {MpiOp::MIN, MPI_MIN}, + {MpiOp::SUM, MPI_SUM}, + {MpiOp::PROD, MPI_PROD}, + {MpiOp::LAND, MPI_LAND}, + {MpiOp::BAND, MPI_BAND}, + {MpiOp::LOR, MPI_LOR}, + {MpiOp::BOR, MPI_BOR}, + {MpiOp::LXOR, MPI_LXOR}, + {MpiOp::BXOR, MPI_BXOR}, + {MpiOp::MINLOC, MPI_MINLOC}, + {MpiOp::MAXLOC, MPI_MAXLOC}, + {MpiOp::REPLACE, MPI_REPLACE}, + }; + return op_map.at(op); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +namespace +{ + +bool mpiInitialized = false; +std::recursive_mutex mpiMutex; + +MpiComm initLocalSession() +{ +#if ENABLE_MULTI_DEVICE + MPI_Comm localComm = nullptr; + MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, &localComm); + MpiComm localSession{localComm, false}; +#else + MpiComm localSession{COMM_SESSION, false}; +#endif // ENABLE_MULTI_DEVICE + return localSession; +} + +} // namespace + +std::vector getWorldRanks(MpiComm const& comm) +{ +#if ENABLE_MULTI_DEVICE + MPI_Group group = nullptr; + MPI_Group worldGroup = nullptr; + + MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); + MPICHECK(MPI_Comm_group(comm, &group)); + + int groupSize = 0; + MPICHECK(MPI_Group_size(group, &groupSize)); + std::vector ranks(groupSize); + std::vector worldRanks(groupSize); + std::iota(ranks.begin(), ranks.end(), 0); + + MPICHECK(MPI_Group_translate_ranks(group, groupSize, ranks.data(), worldGroup, worldRanks.data())); + MPICHECK(MPI_Group_free(&group)); + MPICHECK(MPI_Group_free(&worldGroup)); +#else + std::vector worldRanks{0}; +#endif + return worldRanks; +} + +void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent) +{ + // double-checked locking + if (mpiInitialized) + { + return; + } + std::lock_guard lk(mpiMutex); + if (mpiInitialized) + { + return; + } +#if ENABLE_MULTI_DEVICE + int initialized = 0; + TLLM_MPI_CHECK(MPI_Initialized(&initialized)); + if (!initialized) + { + TLLM_LOG_INFO("Initializing MPI with thread mode %d", threadMode); + int providedMode = 0; + auto requiredMode = static_cast(threadMode); + MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode)); + TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed"); + std::atexit([]() { MPI_Finalize(); }); + + /* + * We only catch SIGABRT and SIGSEGV because most, of not all errors in the worker will cause one of these 2 + * signals. Signals like SIGINT and SIGTERM should be issued to the parent and should terminate MPI workers + * correctly. 
+ */ + for (int sig : {SIGABRT, SIGSEGV}) + { + __sighandler_t previousHandler = nullptr; + if (forwardAbortToParent) + { + previousHandler = std::signal(sig, + [](int signal) + { +#ifndef _WIN32 + pid_t parentProcessId = getppid(); + kill(parentProcessId, SIGKILL); +#endif + MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); + }); + } + else + { + previousHandler = std::signal(sig, [](int signal) { MPI_Abort(MPI_COMM_WORLD, EXIT_FAILURE); }); + } + TLLM_CHECK_WITH_INFO(previousHandler != SIG_ERR, "Signal handler setup failed"); + } + + // ensure local MPI communicator is initialized + MpiComm::localSession(); + TLLM_LOG_INFO("Initialized MPI"); + } +#endif // ENABLE_MULTI_DEVICE + mpiInitialized = true; +} + +void MpiComm::barrier() const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Barrier(mComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +#if ENABLE_MULTI_DEVICE +template >>> +size_t invokeChunked(TMpiFunc func, TBase* buffer, size_t size, MPI_Datatype dtype, TArgs... args) +{ + constexpr auto maxP1 = static_cast(std::numeric_limits::max()) + 1; + if (TLLM_LIKELY(size < maxP1)) + { + MPICHECK(func(buffer, size, dtype, args...)); + return 1; + } + + constexpr size_t alignment = 256; + int elementSize = 1; + MPICHECK(MPI_Type_size(dtype, &elementSize)); + elementSize = std::min(elementSize, alignment); + + // We cap at max alignment-bytes chunks that can be sent at once. + auto const step = maxP1 - (alignment / elementSize); + + using TCast = std::conditional_t, uint8_t const, uint8_t>; + size_t count = 0; + while (size != 0) + { + auto currentStep = static_cast(std::min(size, step)); + MPICHECK(func(buffer, currentStep, dtype, args...)); + size -= currentStep; + size_t diff = static_cast(currentStep) * elementSize; + buffer = static_cast(buffer) + diff; + ++count; + } + + return count; +} +#endif // ENABLE_MULTI_DEVICE + +std::shared_ptr MpiComm::bcastAsync(void* buffer, size_t size, MpiType dtype, int root) const +{ + std::shared_ptr r = std::make_shared(); +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Ibcast, buffer, size, getMpiDtype(dtype), root, mComm, &r->mRequest); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + return r; +} + +std::shared_ptr MpiComm::bcastAsync(runtime::IBuffer& buf, int root) const +{ + TLLM_CHECK(buf.getMemoryType() != runtime::MemoryType::kGPU); + return bcastAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root); +} + +void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const +{ +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Bcast, buffer, size, getMpiDtype(dtype), root, mComm); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::bcast(runtime::IBuffer& buf, int root) const +{ + bcast(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, root); +} + +std::shared_ptr MpiComm::sendAsync(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const +{ + TLLM_LOG_DEBUG("start MPI_Isend with size %d", size); + std::shared_ptr r = std::make_shared(); +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Isend, buffer, size, getMpiDtype(dtype), dest, tag, mComm, &r->mRequest); +#else + TLLM_THROW("Multi device support is disabled."); +#endif + TLLM_LOG_DEBUG("end MPI_Isend with size %d", size); + return r; +} + +std::shared_ptr MpiComm::sendAsync(runtime::IBuffer const& buf, int dest, int tag) const +{ + return sendAsync(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag); +} + +void 
MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const +{ + TLLM_LOG_DEBUG("start MPI_Send with size %d", size); +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Send, buffer, size, getMpiDtype(dtype), dest, tag, mComm); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + TLLM_LOG_DEBUG("end MPI_Send with size %d", size); +} + +void MpiComm::send(runtime::IBuffer const& buf, int dest, int tag) const +{ + send(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, dest, tag); +} + +MPI_Status MpiComm::recv(void* buffer, size_t size, MpiType dtype, int source, int tag) const +{ + TLLM_LOG_DEBUG("start MPI_Recv with size %d", size); + MPI_Status status{}; +#if ENABLE_MULTI_DEVICE + invokeChunked(MPI_Recv, buffer, size, getMpiDtype(dtype), source, tag, mComm, &status); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + TLLM_LOG_DEBUG("end MPI_Recv with size %d", size); + return status; +} + +MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const +{ + return recv(buf.data(), buf.getSizeInBytes(), MpiType::kBYTE, source, tag); +} + +MpiComm MpiComm::split(int color, int key) const +{ + MPI_Comm splitComm = nullptr; +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Comm_split(mComm, color, key, &splitComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE + return MpiComm{splitComm, true}; +} + +void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf, + std::vector const& recvcounts, std::vector const& displs, MpiType recvtype) const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(), + getMpiDtype(recvtype), mComm)); + +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const +{ +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Mprobe(source, tag, mComm, msg, status)); +#else + TLLM_THROW("Multi device support is disabled."); +#endif // ENABLE_MULTI_DEVICE +} + +bool MpiComm::improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const +{ +#if ENABLE_MULTI_DEVICE + int flag{0}; + MPICHECK(MPI_Improbe(source, tag, mComm, &flag, msg, status)); + return flag != 0; +#else + TLLM_THROW("Multi device support is disabled."); + return false; +#endif +} + +bool MpiComm::iprobe(int source, int tag, MPI_Status* status) const +{ +#if ENABLE_MULTI_DEVICE + int flag{0}; + MPICHECK(MPI_Iprobe(source, tag, mComm, &flag, status)); + return flag != 0; +#else + TLLM_THROW("Multi device support is disabled."); + return false; +#endif +} + +void MpiComm::recvPoll(int source, int tag, int periodMs) const +{ + MPI_Status status; + while (!iprobe(source, tag, &status)) + { + 
std::this_thread::sleep_for(std::chrono::milliseconds(periodMs)); + } +} + +int MpiComm::getRank() const +{ + int rank = 0; +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Comm_rank(mComm, &rank)); +#endif + return rank; +} + +int MpiComm::getSize() const +{ + int world_size = 1; +#if ENABLE_MULTI_DEVICE + MPICHECK(MPI_Comm_size(mComm, &world_size)); +#endif + return world_size; +} + +MpiComm const& MpiComm::world() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + static MpiComm commWorld{MPI_COMM_WORLD, false}; + initialize(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return commWorld; +} + +MpiComm& MpiComm::mutableSession() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + static MpiComm commSession{MPI_COMM_WORLD, false}; + initialize(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return commSession; +} + +MpiComm& MpiComm::mutableLocalSession() +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + static MpiComm localSession = initLocalSession(); + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + return localSession; +} + +void MpiComm::refreshLocalSession() +{ +#if ENABLE_MULTI_DEVICE + static std::mutex mutex; + std::unique_lock lock(mutex); + auto initSessionRanks = getWorldRanks(MpiComm::session()); + auto localSessionRanks = getWorldRanks(MpiComm::localSession()); + + // Add to intersectionRanks in order of initSessionRanks + std::vector intersectionRanks; + std::unordered_set localSessionRanksSet(localSessionRanks.begin(), localSessionRanks.end()); + for (auto rank : initSessionRanks) + { + if (localSessionRanksSet.find(rank) != localSessionRanksSet.end()) + { + intersectionRanks.push_back(rank); + } + } + + MPI_Group worldGroup = nullptr; + MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup)); + MPI_Group localGroup = nullptr; + MPICHECK(MPI_Group_incl(worldGroup, intersectionRanks.size(), intersectionRanks.data(), &localGroup)); + MPI_Comm localComm = nullptr; + MPICHECK(MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm)); + MpiComm::mutableLocalSession().mFreeComm = true; + MpiComm::mutableLocalSession() = MpiComm{localComm, false}; + TLLM_LOG_INFO("Refreshed the MPI local session"); +#endif // ENABLE_MULTI_DEVICE +} + +MpiComm::MpiComm(MPI_Comm g, bool freeComm) + : mComm{g} + , mFreeComm{freeComm} +{ + TLLM_CHECK(mComm != MPI_COMM_NULL); +} + +MpiComm::~MpiComm() noexcept +{ +#if ENABLE_MULTI_DEVICE + if (mFreeComm && mComm) + { + if (MPI_Comm_free(&mComm) != MPI_SUCCESS) + { + TLLM_LOG_ERROR("MPI_Comm_free failed"); + } + } +#endif // ENABLE_MULTI_DEVICE +} + +MpiComm::MpiComm(MpiComm&& comm) noexcept + : mComm{comm.mComm} + , mFreeComm{comm.mFreeComm} +{ + comm.mFreeComm = false; +} + +MpiComm& MpiComm::operator=(MpiComm&& comm) noexcept +{ + this->~MpiComm(); + mComm = comm.mComm; + mFreeComm = comm.mFreeComm; + comm.mFreeComm = false; + return *this; +} + +MpiWaitThread::MpiWaitThread(std::string name, std::function funcWait, std::function funcSetup) + : mName{name.c_str()} + , mFuncWait{funcWait} + , mFuncSetup{funcSetup} +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + mThread = std::make_unique(&MpiWaitThread::sideThread, this); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +MpiWaitThread::~MpiWaitThread() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + waitStop(); + mShouldExit.store(true); + notifyStart(); + mThread->join(); + mThread.reset(nullptr); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} 
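+
+// Rough usage sketch (illustrative comment; the class declaration lives in the header, so
+// method visibility is assumed from the definitions in this file): a caller constructs
+// MpiWaitThread with a wait functor, calls notifyStart() to let the side thread run one
+// funcWait() iteration, and calls waitStop() to block until that iteration has finished and
+// the thread is parked again. The destructor performs a final waitStop()/notifyStart()
+// handshake before joining the thread.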
+ +void MpiWaitThread::sideThread() +{ + if (mFuncSetup) + { + mFuncSetup(); + } + while (!mShouldExit.load()) + { + notifyStop(); + waitStart(); + mFuncWait(); + } +} + +void MpiWaitThread::waitStart() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::unique_lock lock(mMutex); + mCondVar.wait(lock, [this] { return mRunning; }); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +void MpiWaitThread::waitStop() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::unique_lock lock(mMutex); + mCondVar.wait(lock, [this] { return !mRunning; }); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +void MpiWaitThread::notifyStart() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::lock_guard lock(mMutex); + mRunning = true; + mCondVar.notify_one(); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +void MpiWaitThread::notifyStop() +{ + TLLM_LOG_TRACE("%s: %s start", mName.c_str(), __PRETTY_FUNCTION__); + std::lock_guard lock(mMutex); + mRunning = false; + mCondVar.notify_one(); + TLLM_LOG_TRACE("%s: %s stop", mName.c_str(), __PRETTY_FUNCTION__); +} + +} // namespace tensorrt_llm::mpi diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h new file mode 100644 index 00000000000..0a9d51975af --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/nvtxUtils.h @@ -0,0 +1,46 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +namespace tensorrt_llm::common::nvtx +{ +inline nvtx3::color nextColor() +{ +#ifndef NVTX_DISABLE + constexpr std::array kColors{nvtx3::color{0xff00ff00}, nvtx3::color{0xff0000ff}, nvtx3::color{0xffffff00}, + nvtx3::color{0xffff00ff}, nvtx3::color{0xff00ffff}, nvtx3::color{0xffff0000}, nvtx3::color{0xffffffff}}; + constexpr auto numColors = kColors.size(); + + static thread_local std::size_t colorId = 0; + auto const color = kColors[colorId]; + colorId = colorId + 1 >= numColors ? 0 : colorId + 1; + return color; +#else + return nvtx3::color{0}; +#endif +} + +} // namespace tensorrt_llm::common::nvtx + +#define NVTX3_SCOPED_RANGE_WITH_NAME(range, name) \ + ::nvtx3::scoped_range range(::tensorrt_llm::common::nvtx::nextColor(), name) +#define NVTX3_SCOPED_RANGE(range) NVTX3_SCOPED_RANGE_WITH_NAME(range##_range, #range) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp new file mode 100644 index 00000000000..39aefda481a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.cpp @@ -0,0 +1,323 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tensorrt_llm/common/opUtils.h" +#include "tensorrt_llm/common/mpiUtils.h" + +#include "cuda.h" +#include +#include +#include +#include +#include +#include +#include + +#ifdef _MSC_VER +#define FN_NAME __FUNCTION__ +#else +#define FN_NAME __func__ +#endif + +#if ENABLE_MULTI_DEVICE + +std::unordered_map* getDtypeMap() +{ + static std::unordered_map dtypeMap = {{nvinfer1::DataType::kFLOAT, ncclFloat32}, + {nvinfer1::DataType::kHALF, ncclFloat16}, {nvinfer1::DataType::kBF16, ncclBfloat16}}; + return &dtypeMap; +} + +namespace +{ + +// Get NCCL unique ID for a group of ranks. +ncclUniqueId getUniqueId(std::set const& group) noexcept +{ + auto const rank = COMM_SESSION.getRank(); + TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank); + ncclUniqueId id; + if (rank == *group.begin()) + { + NCCLCHECK(ncclGetUniqueId(&id)); + for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it) + { + COMM_SESSION.sendValue(id, *it, 0); + } + } + else + { + COMM_SESSION.recvValue(id, *group.begin(), 0); + } + TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank); + return id; +} +} // namespace + +std::shared_ptr getComm(std::set const& group) +{ + auto const rank = COMM_SESSION.getRank(); + TLLM_LOG_TRACE("%s start for rank %d", __PRETTY_FUNCTION__, rank); + static std::map, std::shared_ptr> commMap; + static std::mutex mutex; + std::lock_guard lock(mutex); + std::ostringstream oss; + int index = 0; + for (auto const& rank : group) + { + if (index != 0) + { + oss << ","; + } + oss << rank; + index++; + } + auto groupStr = oss.str(); + auto it = commMap.find(group); + if (it != commMap.end()) + { + auto ncclComm = it->second; + TLLM_LOG_TRACE("NCCL comm for group(%s) is cached for rank %d", groupStr.c_str(), rank); + return ncclComm; + } + + TLLM_LOG_TRACE("Init NCCL comm for group(%s) for rank %d", groupStr.c_str(), rank); + ncclUniqueId id = getUniqueId(group); + int groupRank = 0; + for (auto const& currentRank : group) + { + if (rank == currentRank) + break; + ++groupRank; + } + TLLM_CHECK(groupRank < group.size()); + std::shared_ptr ncclComm(new ncclComm_t, + [](ncclComm_t* comm) + { + ncclCommDestroy(*comm); + delete comm; + }); + NCCLCHECK(ncclCommInitRank(ncclComm.get(), group.size(), id, groupRank)); + commMap[group] = ncclComm; + TLLM_LOG_TRACE("%s stop for rank %d", __PRETTY_FUNCTION__, rank); + return ncclComm; +} +#endif // ENABLE_MULTI_DEVICE + +void const* tensorrt_llm::common::getCommSessionHandle() +{ +#if ENABLE_MULTI_DEVICE + return &COMM_SESSION; +#else + return nullptr; +#endif // ENABLE_MULTI_DEVICE +} + +namespace +{ + +// Get current cuda context, a default context will be created if there is no context. 
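+// (The cudaFree(nullptr) call below is the usual no-op trick to force lazy initialization
+// of the CUDA runtime and its primary context when no context is current yet.)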
+inline CUcontext getCurrentCudaCtx() +{ + CUcontext ctx{}; + CUresult err = cuCtxGetCurrent(&ctx); + if (err == CUDA_ERROR_NOT_INITIALIZED || ctx == nullptr) + { + TLLM_CUDA_CHECK(cudaFree(nullptr)); + err = cuCtxGetCurrent(&ctx); + } + TLLM_CHECK(err == CUDA_SUCCESS); + return ctx; +} + +// Helper to create per-cuda-context singleton managed by std::shared_ptr. +// Unlike conventional singletons, singleton created with this will be released +// when not needed, instead of on process exit. +// Objects of this class shall always be declared static / global, and shall never own CUDA +// resources. +template +class PerCudaCtxSingletonCreator +{ +public: + using CreatorFunc = std::function()>; + using DeleterFunc = std::function; + + // creator returning std::unique_ptr is by design. + // It forces separation of memory for T and memory for control blocks. + // So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released. + // creator itself must not own CUDA resources. Only the object it creates can. + PerCudaCtxSingletonCreator(CreatorFunc creator, DeleterFunc deleter) + : mCreator{std::move(creator)} + , mDeleter{std::move(deleter)} + { + } + + std::shared_ptr operator()() + { + std::lock_guard lk{mMutex}; + CUcontext ctx{getCurrentCudaCtx()}; + std::shared_ptr result = mObservers[ctx].lock(); + if (result == nullptr) + { + // Create the resource and register with an observer. + result = std::shared_ptr{mCreator().release(), + [this, ctx](T* obj) + { + if (obj == nullptr) + { + return; + } + mDeleter(obj); + + // Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts + // frequently. + std::shared_ptr observedObjHolder; // Delay destroy to avoid dead lock. + std::lock_guard lk{mMutex}; + // Must check observer again because another thread may created new instance for this ctx just + // before we lock mMutex. We can't infer that the observer is stale from the fact that obj is + // destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic + // operation, and the observer may be changed to observe another instance. + observedObjHolder = mObservers.at(ctx).lock(); + if (observedObjHolder == nullptr) + { + mObservers.erase(ctx); + } + }}; + mObservers.at(ctx) = result; + } + return result; + } + +private: + CreatorFunc mCreator; + DeleterFunc mDeleter; + mutable std::mutex mMutex; + // CUDA resources are per-context. + std::unordered_map> mObservers; +}; + +template +class PerThreadSingletonCreator +{ +public: + using CreatorFunc = std::function()>; + using DeleterFunc = std::function; + + // creator returning std::unique_ptr is by design. + // It forces separation of memory for T and memory for control blocks. + // So when T is released, but we still have observer weak_ptr in mObservers, the T mem block can be released. + // creator itself must not own CUDA resources. Only the object it creates can. + PerThreadSingletonCreator(CreatorFunc creator, DeleterFunc deleter) + : mCreator{std::move(creator)} + , mDeleter{std::move(deleter)} + { + } + + std::shared_ptr operator()() + { + std::lock_guard lk{mMutex}; + + std::thread::id thread = std::this_thread::get_id(); + std::shared_ptr result = mObservers[thread].lock(); + + if (result == nullptr) + { + // Create the resource and register with an observer. 
+ result = std::shared_ptr{mCreator().release(), + [this, thread](T* obj) + { + if (obj == nullptr) + { + return; + } + mDeleter(obj); + + // Clears observer to avoid growth of mObservers, in case users creates/destroys cuda contexts + // frequently. + std::shared_ptr observedObjHolder; // Delay destroy to avoid dead lock. + std::lock_guard lk{mMutex}; + // Must check observer again because another thread may created new instance for this ctx just + // before we lock mMutex. We can't infer that the observer is stale from the fact that obj is + // destroyed, because shared_ptr ref-count checking and observer removing are not in one atomic + // operation, and the observer may be changed to observe another instance. + observedObjHolder = mObservers.at(thread).lock(); + if (observedObjHolder == nullptr) + { + mObservers.erase(thread); + } + }}; + mObservers.at(thread) = result; + } + return result; + } + +private: + CreatorFunc mCreator; + DeleterFunc mDeleter; + mutable std::mutex mMutex; + // CUDA resources are per-thread. + std::unordered_map> mObservers; +}; + +} // namespace + +std::shared_ptr getCublasHandle() +{ + static PerThreadSingletonCreator creator( + []() -> auto + { + auto handle = std::unique_ptr(new cublasHandle_t); + TLLM_CUDA_CHECK(cublasCreate(handle.get())); + return handle; + }, + [](cublasHandle_t* handle) + { + TLLM_CUDA_CHECK(cublasDestroy(*handle)); + delete handle; + }); + return creator(); +} + +std::shared_ptr getCublasLtHandle() +{ + static PerThreadSingletonCreator creator( + []() -> auto + { + auto handle = std::unique_ptr(new cublasLtHandle_t); + TLLM_CUDA_CHECK(cublasLtCreate(handle.get())); + return handle; + }, + [](cublasLtHandle_t* handle) + { + TLLM_CUDA_CHECK(cublasLtDestroy(*handle)); + delete handle; + }); + return creator(); +} + +std::shared_ptr getCublasMMWrapper(std::shared_ptr cublasHandle, + std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace) +{ + static PerThreadSingletonCreator creator( + [cublasHandle, cublasltHandle, stream, workspace]() -> auto + { + auto wrapper = std::unique_ptr( + new tensorrt_llm::common::CublasMMWrapper(cublasHandle, cublasltHandle, stream, workspace)); + return wrapper; + }, + [](tensorrt_llm::common::CublasMMWrapper* wrapper) { delete wrapper; }); + return creator(); +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h new file mode 100644 index 00000000000..4e278e5cf23 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/opUtils.h @@ -0,0 +1,215 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "tensorrt_llm/common/cublasMMWrapper.h" +#include "tensorrt_llm/common/workspace.h" + +#include +#include +#include +#include +#if ENABLE_MULTI_DEVICE +#include +#endif // ENABLE_MULTI_DEVICE + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tensorrt_llm::common +{ + +// Write values into buffer +template +void write(char*& buffer, T const& val) +{ + std::memcpy(buffer, &val, sizeof(T)); + buffer += sizeof(T); +} + +// Read values from buffer +template +void read(char const*& buffer, T& val) +{ + std::memcpy(&val, buffer, sizeof(T)); + buffer += sizeof(T); +} + +// Like std::unique_ptr, but does not prevent generation of default copy constructor when used as class members. +// The copy constructor produces nullptr. So the plugin default copy constructor will not really copy this, and +// your clone() implementation is responsible for initializing such data members. +// With this we can simplify clone() implementation when there are many data members including at least one unique_ptr. +template > +class UniqPtrWNullCopy : public std::unique_ptr +{ +public: + using std::unique_ptr::unique_ptr; + + // for compatibility with std::make_unique + explicit UniqPtrWNullCopy(std::unique_ptr&& src) + : std::unique_ptr::unique_ptr{std::move(src)} + { + } + + // copy constructor produces nullptr + UniqPtrWNullCopy(UniqPtrWNullCopy const&) + : std::unique_ptr::unique_ptr{} + { + } +}; + +// for testing only +void const* getCommSessionHandle(); +} // namespace tensorrt_llm::common + +inline bool isBuilding() +{ + auto constexpr key = "IS_BUILDING"; + auto const val = getenv(key); + return val != nullptr && std::string(val) == "1"; +} + +#if ENABLE_MULTI_DEVICE +#define NCCLCHECK(cmd) \ + do \ + { \ + ncclResult_t r = cmd; \ + if (r != ncclSuccess) \ + { \ + printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, ncclGetErrorString(r)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +std::unordered_map* getDtypeMap(); + +std::shared_ptr getComm(std::set const& group); + +#endif // ENABLE_MULTI_DEVICE + +//! To save GPU memory, all the plugins share the same cublas and cublasLt handle globally. +//! Get cublas and cublasLt handle for current cuda context +std::shared_ptr getCublasHandle(); +std::shared_ptr getCublasLtHandle(); +std::shared_ptr getCublasMMWrapper(std::shared_ptr cublasHandle, + std::shared_ptr cublasltHandle, cudaStream_t stream, void* workspace); + +#ifndef DEBUG + +#define PLUGIN_CHECK(status) \ + do \ + { \ + if (status != 0) \ + abort(); \ + } while (0) + +#define ASSERT_PARAM(exp) \ + do \ + { \ + if (!(exp)) \ + return STATUS_BAD_PARAM; \ + } while (0) + +#define ASSERT_FAILURE(exp) \ + do \ + { \ + if (!(exp)) \ + return STATUS_FAILURE; \ + } while (0) + +#define CSC(call, err) \ + do \ + { \ + cudaError_t cudaStatus = call; \ + if (cudaStatus != cudaSuccess) \ + { \ + return err; \ + } \ + } while (0) + +#define DEBUG_PRINTF(...) 
\ + do \ + { \ + } while (0) + +#else + +#define ASSERT_PARAM(exp) \ + do \ + { \ + if (!(exp)) \ + { \ + fprintf(stderr, "Bad param - " #exp ", %s:%d\n", __FILE__, __LINE__); \ + return STATUS_BAD_PARAM; \ + } \ + } while (0) + +#define ASSERT_FAILURE(exp) \ + do \ + { \ + if (!(exp)) \ + { \ + fprintf(stderr, "Failure - " #exp ", %s:%d\n", __FILE__, __LINE__); \ + return STATUS_FAILURE; \ + } \ + } while (0) + +#define CSC(call, err) \ + do \ + { \ + cudaError_t cudaStatus = call; \ + if (cudaStatus != cudaSuccess) \ + { \ + printf("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(cudaStatus)); \ + return err; \ + } \ + } while (0) + +#define PLUGIN_CHECK(status) \ + { \ + if (status != 0) \ + { \ + DEBUG_PRINTF("%s %d CUDA FAIL %s\n", __FILE__, __LINE__, cudaGetErrorString(status)); \ + abort(); \ + } \ + } + +#define DEBUG_PRINTF(...) \ + do \ + { \ + printf(__VA_ARGS__); \ + } while (0) + +#endif // DEBUG + +#define NVML_CHECK(cmd) \ + do \ + { \ + nvmlReturn_t r = cmd; \ + if (r != NVML_SUCCESS) \ + { \ + printf("Failed, NVML error %s:%d '%s'\n", __FILE__, __LINE__, nvmlErrorString(r)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh new file mode 100644 index 00000000000..a228d3f9fc6 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/quantTypeUtils.cuh @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh" +#include "tensorrt_llm/common/cudaFp8Utils.h" +#include +#include +#include + +namespace tensorrt_llm +{ +namespace common +{ + +template +struct QuantTypeStaticVals; + +template <> +struct QuantTypeStaticVals +{ + static constexpr float MAX_VAL = 127.f; + static constexpr float MIN_SCALING_FACTOR = 0.f; + static constexpr float MIN_SCALING_FACTOR_RCP = FLT_MAX; +}; + +#ifdef ENABLE_FP8 + +template <> +struct QuantTypeStaticVals<__nv_fp8_e4m3> +{ + static constexpr float MAX_VAL = 448.f; + // Ref: https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L720 + static constexpr float MIN_SCALING_FACTOR = 1.0f / (448.f * 512.f); + static constexpr float MIN_SCALING_FACTOR_RCP = (448.f * 512.f); +}; + +#endif // ENABLE_FP8 + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh b/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh new file mode 100644 index 00000000000..c5a4fe0e24e --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/reduceKernelUtils.cuh @@ -0,0 +1,399 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) +#include +#else +#include +#endif +#include "tensorrt_llm/common/cudaTypeUtils.cuh" +#include +#include +#include +#include +#include + +namespace cg = cooperative_groups; + +namespace tensorrt_llm +{ +namespace common +{ + +template +struct BytesToType; + +template <> +struct BytesToType<1> +{ + using type = uint8_t; +}; + +template <> +struct BytesToType<2> +{ + using type = uint16_t; +}; + +template <> +struct BytesToType<4> +{ + using type = uint32_t; +}; + +template <> +struct BytesToType<8> +{ + using type = uint64_t; +}; + +template <> +struct BytesToType<16> +{ + using type = float4; +}; + +template +__device__ inline void copy(void const* local, void* data) +{ + using T = typename BytesToType::type; + + T const* in = static_cast(local); + T* out = static_cast(data); + *out = *in; +} + +static float constexpr HALF_FLT_MAX = 65504.F; +#define FINAL_MASK 0xffffffff + +template +__inline__ __device__ T warpReduceSum(T val) +{ +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80 + return val; +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockReduceSum(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T) (0.0f); + val = warpReduceSum(val); + + return val; +} + +template +__inline__ __device__ T warpReduceMax(T val) +{ +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); + return val; +} + +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockReduceMax(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + val = warpReduceMax(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; + val = warpReduceMax(val); + + return val; +} + +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockAllReduceMax(T val) +{ + static __shared__ T shared[32]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + val = warpReduceMax(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. 
to prevent + // blockDim.x is not divided by 32 + val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; + val = warpReduceMax(val); + + return val; +} + +template +__inline__ __device__ T warpReduceSumV2(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) + { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T) (0.0f); +} + +template +__inline__ __device__ T blockReduceSumV2(T* val) +{ + static __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSumV2(val); + + if (lane == 0) + { +#pragma unroll + for (int i = 0; i < NUM; i++) + { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) + { + val[i] = is_mask ? shared[i][lane] : (T) (0.0f); + } + warpReduceSumV2(val); + return (T) 0.0f; +} + +template +__inline__ __device__ T warpReduceMaxV2(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) + { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); + } + return (T) (0.0f); +} + +template +__inline__ __device__ T blockReduceMaxV2(T* val) +{ + static __shared__ T shared[32][NUM]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + warpReduceMaxV2(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + { +#pragma unroll + for (int i = 0; i < NUM; i++) + { + shared[wid][i] = val[i]; + } + } + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) + { + val[i] = is_mask ? shared[lane][i] : (T) -1e20f; + } + warpReduceMaxV2(val); + + return (T) 0.0f; +} + +template +__inline__ __device__ void cgBlockReduceSumElements(float* element_list, float* cgBlockReduceSumElements_shm) +{ + cg::thread_block cta = cg::this_thread_block(); + cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta); + + int const tid = cta.thread_rank(); + int const blockz = blockDim.x; + for (int i = 0; i < NUM; i++) + { +#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) + cgBlockReduceSumElements_shm[i * blockz + tid] = cg::reduce(tile, element_list[i], cg::plus()); +#else + // TODO Add implementation here + if (threadIdx.x == 0 && blockIdx.x == 0) + { + printf("[ERROR] Not support cgBlockReduceSumElements when CUDA < 11 \n"); + assert(false); + } +#endif + } + cg::sync(cta); + if (tid == 0) + { +#pragma unroll + for (int i = 0; i < NUM; i++) + { + float beta = 0.0f; + for (int j = 0; j < blockz; j += 32) + { + beta += cgBlockReduceSumElements_shm[i * blockz + j]; + } + element_list[i] = beta; + } + } +} + +template +struct TopK +{ + int p[MAX_K]; // index, being -1 at the tail if the array is not full + T u[MAX_K]; // value in descend order, being -MAX_T_VAL if the element is invalid + + __device__ __forceinline__ void insert(T const elem, int const elem_id) + { + if (elem_id < 0) + { + return; + } + // Condition of updating the array + // 1. array is not full + // 2. elem is greater than the smallest (last) element in the array + // 3. 
elem is equal to the smallest (last) element in the array but its elem_id is smaller + bool const need_update + = (p[MAX_K - 1] == -1 || elem > u[MAX_K - 1] || elem == u[MAX_K - 1] && elem_id < p[MAX_K - 1]); + if (!need_update) + { + return; + } + // Find suitable index for the new element + int i; + for (i = MAX_K - 2; i >= 0; --i) + { + bool const need_decrease = (p[i] == -1 || elem > u[i] || elem == u[i] && elem_id < p[i]); + if (!need_decrease) + break; + } + // Move elements to correct positions + for (int k = MAX_K - 2; k >= i; --k) + { + p[k + 1] = p[k]; + u[k + 1] = u[k]; + } + p[i] = elem_id; + u[i] = elem; + } + + __device__ __forceinline__ void init() + { + T const MAX_T_VAL = (std::is_same::value) ? HALF_FLT_MAX : FLT_MAX; + for (int i = 0; i < MAX_K; i++) + { + p[i] = -1; + u[i] = -MAX_T_VAL; + } + } +}; + +template +__device__ __forceinline__ TopK reduce_topk_op(TopK const& a, TopK const& b) +{ + TopK res = a; + for (int i = 0; i < MAX_K; ++i) + res.insert(b.u[i], b.p[i]); + return res; +} + +template +struct TopK_2 +{ + int p = -1; + T u = -((std::is_same::value) ? HALF_FLT_MAX : FLT_MAX); + + __device__ __forceinline__ void insert(T elem, int elem_id) + { + if (elem > u) + { + u = elem; + p = elem_id; + } + } + + __device__ __forceinline__ void init() + { + u = -((std::is_same::value) ? HALF_FLT_MAX : FLT_MAX); + p = -1; + } +}; + +template +__device__ __forceinline__ TopK_2 reduce_topk_op_2(TopK_2 const& a, TopK_2 const& b) +{ + return a.u > b.u ? a : b; +} + +template +__device__ __forceinline__ T clamp_inf_for_half(float const input) +{ + return input; +} + +template <> +__device__ __forceinline__ half clamp_inf_for_half(float const input) +{ + // clamp inf values to enable fp16 training + return input > 0.0f ? (half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000); +} + +} // namespace common +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h new file mode 100644 index 00000000000..9cda9fa0d42 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/stlUtils.h @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace tensorrt_llm::common::stl_utils +{ + +template +constexpr TOutputIt basicInclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, TBinOp op) +{ + if (first != last) + { + auto val = *first; + while (true) + { + *dFirst = val; + ++dFirst; + ++first; + if (first == last) + { + break; + } + val = op(std::move(val), *first); + } + } + return dFirst; +} + +template +constexpr TOutputIt inclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst) +{ +#if defined(__GNUC__) && __GNUC__ <= 8 + return basicInclusiveScan(first, last, dFirst, std::plus<>{}); +#else + return std::inclusive_scan(first, last, dFirst); +#endif +} + +template +constexpr TOutputIt basicExclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init, TBinOp op) +{ + if (first != last) + { + while (true) + { + T tmp{op(init, *first)}; + *dFirst = init; + ++dFirst; + ++first; + if (first == last) + { + break; + } + init = std::move(tmp); + } + } + return dFirst; +} + +template +constexpr TOutputIt exclusiveScan(TInputIt first, TInputIt last, TOutputIt dFirst, T init) +{ +#if defined(__GNUC__) && __GNUC__ <= 8 + return basicExclusiveScan(first, last, dFirst, std::move(init), std::plus<>{}); +#else + return std::exclusive_scan(first, last, dFirst, std::move(init)); +#endif +} + +template +struct HasOperatorOutput : std::false_type +{ +}; + +template +struct HasOperatorOutput() << std::declval()))>> + : std::true_type +{ +}; + +template +std::string toString(T const& t, typename std::enable_if_t::value, int> = 0) +{ + std::ostringstream oss; + oss << t; + return oss.str(); +} + +template +std::string toString(std::optional const& t, typename std::enable_if_t::value, int> = 0) +{ + std::ostringstream oss; + if (t) + { + oss << t.value(); + } + else + { + oss << "None"; + } + return oss.str(); +} + +} // namespace tensorrt_llm::common::stl_utils diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp new file mode 100644 index 00000000000..f1c6f88b431 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/stringUtils.cpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/stringUtils.h" +#include "tensorrt_llm/common/assert.h" + +#include +#include +#include +#include +#include + +namespace tensorrt_llm::common +{ + +namespace +{ +std::string vformat(char const* fmt, va_list args) +{ + va_list args0; + va_copy(args0, args); + auto const size = vsnprintf(nullptr, 0, fmt, args0); + if (size <= 0) + return ""; + + std::string stringBuf(size, char{}); + auto const size2 = std::vsnprintf(&stringBuf[0], size + 1, fmt, args); + + TLLM_CHECK_WITH_INFO(size2 == size, std::string(std::strerror(errno))); + + return stringBuf; +} + +} // namespace + +std::string fmtstr(char const* format, ...) 
+{ + va_list args; + va_start(args, format); + std::string result = vformat(format, args); + va_end(args); + return result; +}; + +std::unordered_set str2set(std::string const& input, char delimiter) +{ + std::unordered_set values; + if (!input.empty()) + { + std::stringstream valStream(input); + std::string val; + while (std::getline(valStream, val, delimiter)) + { + if (!val.empty()) + { + values.insert(val); + } + } + } + return values; +}; + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp new file mode 100644 index 00000000000..c00041abdac --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.cpp @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "tensorrt_llm/common/timestampUtils.h" + +namespace tensorrt_llm::common +{ + +std::string getCurrentTimestamp() +{ + auto now = std::chrono::system_clock::now(); + auto now_t = std::chrono::system_clock::to_time_t(now); + auto tm = *std::localtime(&now_t); + + auto epoch_to_now = now.time_since_epoch(); + auto seconds = std::chrono::duration_cast(epoch_to_now); + auto us = std::chrono::duration_cast(epoch_to_now - seconds); + + std::ostringstream stream; + stream << std::put_time(&tm, "%m-%d-%Y %H:%M:%S"); + stream << "." << std::setfill('0') << std::setw(6) << us.count(); + return stream.str(); +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h b/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h new file mode 100644 index 00000000000..f52f23028c1 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/timestampUtils.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +namespace tensorrt_llm::common +{ + +/// @brief Get the current timestamp in the format "MM-DD-YYYY HH:MM:SS:uuuuuu" +std::string getCurrentTimestamp(); + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp b/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp new file mode 100644 index 00000000000..b410613d055 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp @@ -0,0 +1,105 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/common/tllmException.h" +#include "tensorrt_llm/common/stringUtils.h" + +#include +#if !defined(_MSC_VER) +#include +#include +#include +#endif +#include + +namespace tensorrt_llm::common +{ + +namespace +{ +int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2; +} + +#if !defined(_MSC_VER) + +TllmException::TllmException(char const* file, std::size_t line, std::string const& msg) + : std::runtime_error{""} +{ + mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES); + auto const trace = getTrace(); + std::runtime_error::operator=( + std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())}); +} +#else +TllmException::TllmException(char const* file, std::size_t line, std::string const& msg) + : mNbFrames{} + , std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)} +{ +} +#endif + +TllmException::~TllmException() noexcept = default; + +std::string TllmException::getTrace() const +{ +#if defined(_MSC_VER) + return ""; +#else + auto const trace = backtrace_symbols(mCallstack.data(), mNbFrames); + std::ostringstream buf; + for (auto i = 1; i < mNbFrames; ++i) + { + Dl_info info; + if (dladdr(mCallstack[i], &info) && info.dli_sname) + { + auto const clearName = demangle(info.dli_sname); + buf << fmtstr("%-3d %*p %s + %zd", i, VOID_PTR_SZ, mCallstack[i], clearName.c_str(), + static_cast(mCallstack[i]) - static_cast(info.dli_saddr)); + } + else + { + buf << fmtstr("%-3d %*p %s", i, VOID_PTR_SZ, mCallstack[i], trace[i]); + } + if (i < mNbFrames - 1) + buf << std::endl; + } + + if (mNbFrames == MAX_FRAMES) + buf << std::endl << "[truncated]"; + + std::free(trace); + return buf.str(); +#endif +} + +std::string TllmException::demangle(char const* name) +{ +#if defined(_MSC_VER) + return name; +#else + std::string clearName{name}; + auto status = -1; + auto const demangled = abi::__cxa_demangle(name, nullptr, nullptr, &status); + if (status == 0) + { + clearName = demangled; + std::free(demangled); + } + return clearName; +#endif +} + +} // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h b/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h new file mode 100644 index 00000000000..1406e821333 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h @@ -0,0 +1,87 @@ +/* + * Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include + +namespace tensorrt_llm::common +{ + +std::uintptr_t constexpr kCudaMemAlign = 128; + +inline int8_t* alignPtr(int8_t* ptr, uintptr_t to) +{ + uintptr_t addr = (uintptr_t) ptr; + if (addr % to) + { + addr += to - addr % to; + } + return (int8_t*) addr; +} + +constexpr size_t alignSize(size_t size, size_t to) +{ + if ((size % to) != 0U) + { + size += to - size % to; + } + return size; +} + +inline int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment) +{ + uintptr_t addr = (uintptr_t) ptr; + addr += previousWorkspaceSize; + return alignPtr((int8_t*) addr, alignment); +} + +inline int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize) +{ + return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign); +} + +inline int8_t* nextWorkspacePtr( + int8_t* const base, uintptr_t& offset, uintptr_t const size, uintptr_t const alignment = kCudaMemAlign) +{ + uintptr_t curr_offset = offset; + uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment; + int8_t* newptr = size == 0 ? nullptr : base + curr_offset; + offset = next_offset; + return newptr; +} + +inline int8_t* nextWorkspacePtrWithAlignment( + int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment = kCudaMemAlign) +{ + return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment); +} + +inline size_t calculateTotalWorkspaceSize( + size_t const* workspaces, int count, uintptr_t const alignment = kCudaMemAlign) +{ + size_t total = 0; + for (int i = 0; i < count; i++) + { + total += workspaces[i]; + if (workspaces[i] % alignment) + { + total += alignment - (workspaces[i] % alignment); + } + } + return total; +} + +}; // namespace tensorrt_llm::common diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp new file mode 100644 index 00000000000..61a41031bfb --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp @@ -0,0 +1,352 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include + +// Config + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10)) +#define CUTE_ARCH_RED_F16_SM70_ENABLED +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) +#define CUTE_ARCH_RED_VEC_SM90_ENABLED +#define CUTE_ARCH_RED_BF16_SM90_ENABLED +#endif + +namespace cute +{ + +////////////////////////////////// +// Wrapper around CUDA's atomicAdd +////////////////////////////////// + +template +struct TypedAtomicAdd +{ + using SRegisters = T[1]; + using DRegisters = T[1]; + + CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst) + { + atomicAdd(&dst, src); + } +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// +// F16 ADD PTX +////////////////////////////////// + +struct SM70_RED_ADD_NOFTZ_F16 +{ + using SRegisters = uint16_t[1]; + using DRegisters = uint16_t[1]; + + CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED) + asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +struct SM70_RED_ADD_NOFTZ_F16x2 +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED) + asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx 
(one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +struct SM90_RED_ADD_NOFTZ_F16x2_V2 +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint64_t[1]; + + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED) + asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +struct SM90_RED_ADD_NOFTZ_F16x2_V4 +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void copy( + uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED) + asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1), + "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// +// BF16 ADD PTX +////////////////////////////////// + +struct SM90_RED_ADD_NOFTZ_BF16 +{ + using SRegisters = uint16_t[1]; + using DRegisters = uint16_t[1]; + + CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) + asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// + +struct SM90_RED_ADD_NOFTZ_BF16x2 +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) + asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from 
(dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// + +struct SM90_RED_ADD_NOFTZ_BF16x2_V2 +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint64_t[1]; + + CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) + asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// + +struct SM90_RED_ADD_NOFTZ_BF16x2_V4 +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void copy( + uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst) + { +#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED) + asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1), + "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED."); +#endif + } +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +////////////////////////////////// + +} // end namespace cute diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h new file mode 100644 index 00000000000..2362da4f7f2 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h @@ -0,0 +1,120 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Templates exposing architecture support for multiply-add operations
+*/
+
+#pragma once
+#include "cutlass_extensions/weight_only_quant_op.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass
+{
+namespace arch
+{
+
+// Tag which triggers MMA which will trigger
+struct OpMultiplyAddDequantizeInterleavedBToA;
+
+/*
+    Below we have extra tags to signal what kind of dequantization we want to do
+    (per col, scale only fine grained, finegrained with zero). This still lets us use
+    the existing template infrastructure (incl. that in CUTLASS). However, we
+    split out the template below into OpMultiplyAddDequantizeInterleavedBToA along
+    with the quantization op before instantiating the GEMM pieces.
+
+    Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of
+    code we need to duplicate.
+ */
+struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
+struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
+struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
+
+// The default just forwards the original operator
+template <typename MmaOp, WeightOnlyQuantOp QuantOp_>
+struct TagOperator
+{
+    using TaggedOperator = MmaOp;
+};
+
+// Specializations below attach more information to the operator
+template <>
+struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>
+{
+    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
+};
+
+template <>
+struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>
+{
+    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
+};
+
+template <>
+struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>
+{
+    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
+};
+
+// Here we instantiate some structs to "detag" the tagged operator. It splits it back into the original
+// operator + the extra information. If no extra info was tagged, per-column scaling is assumed
+// as a default.
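+// Illustrative only (not part of the upstream header): the intended round trip through TagOperator
+// above and DetagOperator below, using the operators declared in this file (std::is_same from
+// <type_traits> assumed for the checks):
+//
+//   using Tagged = typename TagOperator<OpMultiplyAddDequantizeInterleavedBToA,
+//                                       WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>::TaggedOperator;
+//   // Tagged is OpMultiplyAddDequantizeInterleavedBToA_fine_scale.
+//   using Detagged = DetagOperator<Tagged>;
+//   static_assert(std::is_same<Detagged::Operator, OpMultiplyAddDequantizeInterleavedBToA>::value, "");
+//   static_assert(Detagged::QuantOp == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, "");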
+template <typename TaggedMmaOp>
+struct DetagOperator
+{
+    using Operator = TaggedMmaOp;
+    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
+};
+
+template <>
+struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale>
+{
+    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
+    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
+};
+
+template <>
+struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scale>
+{
+    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
+    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
+};
+
+template <>
+struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias>
+{
+    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
+    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
+};
+
+} // namespace arch
+} // namespace cutlass
diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h
new file mode 100644
index 00000000000..c83a9a074da
--- /dev/null
+++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h
@@ -0,0 +1,88 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cuda_runtime_api.h>
+
+#include "cutlass/device_kernel.h"
+#include "tensorrt_llm/common/cudaUtils.h"
+
+namespace tensorrt_llm
+{
+namespace cutlass_extensions
+{
+
+template <typename GemmKernel, bool enable_cutlass_3x = false>
+inline int compute_occupancy_for_kernel()
+{
+
+    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));
+
+    if (smem_size > (48 << 10))
+    {
+        cudaFuncAttributes attr;
+        int device = 0;
+        int max_smem_per_block = 0;
+        tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
+        tensorrt_llm::common::check_cuda_error(
+            cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
+        if constexpr (enable_cutlass_3x)
+        {
+            tensorrt_llm::common::check_cuda_error(
+                cudaFuncGetAttributes(&attr, cutlass::device_kernel<GemmKernel>));
+        }
+        else
+        {
+            tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
+        }
+        if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
+        {
+            // This should mean that
+            // cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
+            // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this
+            // configuration.
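+            // (An occupancy of 0 simply removes this tile configuration from the heuristic's
+            // candidate list instead of treating it as a hard error.)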
+ return 0; + } + + if constexpr (enable_cutlass_3x) + { + tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute( + cutlass::device_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + else + { + tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute( + cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + } + + int max_active_blocks = -1; + if constexpr (enable_cutlass_3x) + { + tensorrt_llm::common::check_cuda_error( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel, + 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size)); + } + else + { + tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, cutlass::Kernel, GemmKernel::kThreadCount, smem_size)); + } + + return max_active_blocks; +} + +} // namespace cutlass_extensions +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp new file mode 100644 index 00000000000..bba25ec23a9 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp @@ -0,0 +1,550 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. 
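+    In this extension, the header provides EpilogueMoeFusedFinalize, which applies the thread-level
+    epilogue to each accumulator tile, adds an optional bias, applies a scale factor, and
+    scatter-adds the resulting rows into the final MoE output tensor using vectorized
+    red.global.add instructions.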
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/fast_math.h" + +#include "cute/numeric/numeric_types.hpp" +#include "cute/tensor.hpp" +#include "cutlass/trace.h" + +#include "cutlass_extensions/arch/copy_red_global.hpp" +#include "cutlass_extensions/util/gather_tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace epilogue +{ +namespace collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class EpilogueMoeFusedFinalize +{ +public: + using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized; + using DispatchPolicy = PtrArrayNoSmemWarpSpecialized; + + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementIntermediate = typename ThreadEpilogueOp::ElementD; + + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + + static_assert(!is_same_v, "Stride C must be a pointer"); + static_assert(is_same_v, "Stride D must not be a pointer"); + + using CopyAtomR2S = Copy_Atom; + using CopyAtomS2R = Copy_Atom; + using CopyAtomR2G = Copy_Atom; + static constexpr int AlignmentD = CopyAtomR2G::NumValSrc; + + using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{})); + + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + + struct SharedStorage + { + alignas(SmemAlignmentD) cute::ArrayEngine> smem_D; + }; + + struct TensorMapStorage + { + }; + + struct Arguments + { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C{}; + StrideC dC{}; + ElementD* ptr_D{}; + StrideD dD{}; + ElementBias const* ptr_bias; + StrideBias dBias{}; + ElementScale const* ptr_scale; + StrideScale dScale{}; + int64_t const* group_offset{}; + int32_t const* scatter_index{}; + cutlass::FastDivmod num_rows_in_final_output; + }; + + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace) + { + return args; + } + + template + static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0) + { + return 0; + } + + template + static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, + void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr) + { + return cutlass::Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool can_implement( + [[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args) + { + bool implementable = true; + if (problem_shape.is_host_problem_shape_available()) + { + // Check alignment for all problem sizes + for (int i = 0; i < problem_shape.groups(); i++) + { + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); + auto [M, N, K, L] = problem_shape_MNKL; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(M, N, L), InternalStrideD{}); + } + } + + if 
(!implementable) + { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global " + "reduction instruction.\n"); + } + return implementable; + } + + CUTLASS_HOST_DEVICE + EpilogueMoeFusedFinalize(Params const& params_) + : params(params_) + { + } + + CUTLASS_DEVICE + bool is_source_needed() + { + // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. + return params.ptr_C != nullptr + && (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0); + } + + template + CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, cute::Tensor const& accumulators, TiledMma tiled_mma, + ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + auto synchronize = [&]() + { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto mma_tile_m = tile_size<0>(tiled_mma); + auto mma_tile_n = tile_size<1>(tiled_mma); + auto epi_tile_m = size<0>(EpilogueTile{}); + auto epi_tile_n = size<1>(EpilogueTile{}); + + CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); + CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + + // Batches are managed by using appropriate pointers to C and D matrices + int32_t const mock_L = 1; + int32_t const mock_l_coord = 0; + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + + // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, + // we get the correct alpha/beta values for the current batch/group using group index. + ThreadEpilogueOp epilogue_op(params.thread, l_coord); + + SharedStorage& storage = *reinterpret_cast(smem_buf); + + Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{}); + Tensor sD = as_position_independent_swizzle_tensor(sD_); + + // Function to scatter output rows + auto& num_rows = params.num_rows_in_final_output; + auto read_scatter_map = IndexedGather(make_gmem_ptr(params.scatter_index + params.group_offset[l_coord])); + auto get_scatter_idx = [&](auto i) + { + auto scatter = read_scatter_map(i); + int quot, rem; + num_rows(quot, rem, scatter); + return rem; + }; + + // Represent the full output tensor + ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr; + auto dC = epilogue_op.is_source_needed() ? 
params.dC[l_coord] : InternalStrideC{}; + Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC); // (m,n,l) + Tensor mD_mnl = make_gather_tensor( + make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx); // (m,n,l) + + // Use fake shape for bias, it doesn't matter + bool const is_bias_needed = params.ptr_bias != nullptr; + Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias); + Tensor mScale_mnl = make_tensor( + make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale); + + Tensor gC_mnl + = local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl + = local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N) + + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + Tensor gBias_mnl + = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gScale_mnl + = local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N) + Tensor gScale = gScale_mnl(_, _, m_coord, n_coord); // (BLK_M,BLK_N) + + Tensor gBias_epi = flat_divide(gBias, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Get the smallest tiled copy we can use to retile the accumulators + TiledCopy tiled_copy_C_atom + = make_tiled_copy_C_atom(Copy_Atom{}, tiled_mma); + TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom); + + auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_sD = thread_r2s.partition_D(sD); // ((R2S,R2S_V),R2S_M,R2S_N) + Tensor tRS_rD = make_tensor(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N) + + // Make a tiled copy vectorized along major direction of D + auto tiled_s2r = [&]() + { + if constexpr (cutlass::gemm::detail::is_k_major()) + { + constexpr int NumThreadsMajor = epi_tile_n / AlignmentD; + constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; + return make_tiled_copy(CopyAtomS2R{}, + Layout, Int>, Stride, _1>>{}, + Layout>>{}); + } + else if constexpr (cutlass::gemm::detail::is_mn_major()) + { + constexpr int NumThreadsMajor = epi_tile_m / AlignmentD; + constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor; + return make_tiled_copy(CopyAtomS2R{}, + Layout, Int>, Stride<_1, Int>>{}, + Layout, _1>>{}); + } + else + { + static_assert(cute::is_void_v, "Unsupported D gmem layout."); + } + }(); + + auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); + Tensor tSR_sD = thread_s2r.partition_S(sD); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_gD = thread_s2r.partition_D(gD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + Tensor tSR_gC = thread_s2r.partition_D(gC_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi); // 
((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + + // Allocate intermediate registers for a single subtile + Tensor tSR_rD = make_tensor(take<0, 3>(shape(tSR_gD))); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rD_final = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rC = make_tensor(shape(tSR_rD)); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rBias = make_tensor(tSR_gBias(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) + Tensor tSR_rScale = make_tensor(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N) + + // Make an identity coordinate tensor for predicating our output MN tile + Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor cD_epi = flat_divide(cD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor tSR_cD = thread_s2r.partition_D(cD_epi); // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N) + + // epilogue subtile loop + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) + { + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) + { + int mma_m = (epi_m * epi_tile_m) / mma_tile_m; + int mma_n = (epi_n * epi_tile_n) / mma_tile_n; + Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n); + + int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); + int r2s_v = epi_n_in_mma * size(tRS_rD); + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v) + { + tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v); + } + + copy(tiled_r2s, tRS_rD, tRS_sD); + synchronize(); + + copy(tiled_s2r, tSR_sD, tSR_rD); + synchronize(); + + Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n); + Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n); + Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n); + Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n); + Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n); + + if (epilogue_op.is_source_needed()) + { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_rD); ++m) + { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_rD); ++n) + { + if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) + { + copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n)); + if (is_bias_needed) + { + copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); + } + copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rD); ++i) + { + auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n)); + if (is_bias_needed) + { + epi_value += static_cast(tSR_rBias(i, m, n)); + } + tSR_rD_final(i, m, n) = static_cast(tSR_rScale(i, m, n) * epi_value); + } + copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); + } + } + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_rD); ++m) + { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_rD); ++n) + { + if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) + { + if (is_bias_needed) + { + copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n)); + } + copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rD); ++i) + { + auto epi_value = epilogue_op(tSR_rD(i, m, n)); + if (is_bias_needed) + { + epi_value += static_cast(tSR_rBias(i, m, n)); + } + tSR_rD_final(i, m, n) = static_cast(tSR_rScale(i, m, n) * epi_value); + } + copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n)); + } + } + } + } + } + } + } + +private: + Params params; +}; + +namespace detail +{ + +template +constexpr auto get_vectorized_atomic_add_op() 
+{ + using namespace cute; + + auto constexpr MaxVecSize = size(MaxVec{}); + + if constexpr (is_same_v) + { + if constexpr (MaxVecSize >= 8) + { + return SM90_RED_ADD_NOFTZ_F16x2_V4{}; + } + else if constexpr (MaxVecSize >= 4) + { + return SM90_RED_ADD_NOFTZ_F16x2_V2{}; + } + else if constexpr (MaxVecSize >= 2) + { + return SM70_RED_ADD_NOFTZ_F16x2{}; + } + else + { + return SM70_RED_ADD_NOFTZ_F16{}; + } + } + else if constexpr (is_same_v) + { + if constexpr (MaxVecSize >= 8) + { + return SM90_RED_ADD_NOFTZ_BF16x2_V4{}; + } + else if constexpr (MaxVecSize >= 4) + { + return SM90_RED_ADD_NOFTZ_BF16x2_V2{}; + } + else if constexpr (MaxVecSize >= 2) + { + return SM90_RED_ADD_NOFTZ_BF16x2{}; + } + else + { + return SM90_RED_ADD_NOFTZ_BF16{}; + } + } + else + { + // non-vectorized atomic add for all other types until supported + return TypedAtomicAdd{}; + } +} + +} // namespace detail + +template +struct EpilogueMoeFusedFinalizeBuilder +{ + + // assuming cooperative kernel schedule + using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{})); + using EpilogueTile = Shape<_128, EpiTileN>; + + // Output of linear combination is ElementCompute instead of ElementD + // since we will be doing more computate on it, no need to cast yet. + using ThreadEpilogueOp + = cutlass::epilogue::thread::LinearCombination; + + using SmemLayoutAtomD + = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()); + using CopyAtomR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator()); + using CopyAtomS2R = DefaultCopy; + using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op()); + + template + struct Sm90TmaWarpSpecializedAdapterWithSmemStorage : detail::Sm90TmaWarpSpecializedAdapter + { + // We need to override this one using declaration because otherwise we double up on the smem + using TensorMapStorage = typename EpilogueOp::TensorMapStorage; + + using Base = detail::Sm90TmaWarpSpecializedAdapter; + + CUTLASS_HOST_DEVICE + Sm90TmaWarpSpecializedAdapterWithSmemStorage( + typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors) + : Base(params) + { + } + + // These functions depend on the type of TensorMapStorage + template + CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t next_batch) + { + } + + template + CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormap, + [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] uint32_t lane_predicate) + { + } + }; + + using CollectiveOp = Sm90TmaWarpSpecializedAdapterWithSmemStorage< + EpilogueMoeFusedFinalize>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h new file mode 100644 index 00000000000..f3c622b88a5 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h @@ -0,0 +1,105 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination with a maximum operation used by epilogues. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/half.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace epilogue +{ +namespace thread +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +__forceinline__ __device__ float copysignf_pos(float a, float b) +{ + float r; + r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000)); + return r; +} + +__forceinline__ __device__ float tanh_opt(float x) +{ +#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) + float const exp_val = -1.f * fabs(2 * x); + return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); +#else + return fast_tanh(x); +#endif +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +template <> +struct GELU_taylor +{ + static bool const kIsHeavy = true; + + CUTLASS_DEVICE + float operator()(float const& z) const + { + + float k0 = float(0.7978845608028654); + float k1 = float(0.044715); + + return float(cutlass::constants::half() * z + * (cutlass::constants::one() + tanh_opt(k0 * z * (cutlass::constants::one() + k1 * z * z)))); + } + + using Params = LinearCombinationGenericParams; + + CUTLASS_DEVICE + float operator()(float const& scalar, Params const& params_) const + { + return this->operator()(scalar); + } +}; + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h new file mode 100644 index 00000000000..d3d4d0a45ab --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h @@ -0,0 +1,352 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column. + + original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/numeric_conversion.h" +#include "tensorrt_llm/common/quantization.h" + +namespace tk = tensorrt_llm::common; + +namespace cutlass +{ +namespace epilogue +{ +namespace threadblock +{ + +template +class EpilogueVisitorPerRowPerCol +{ +public: + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using ScaleTileIterator = ScaleTileIterator_; + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + using ElementOutput = typename OutputTileIterator::Element; + using LayoutOutput = cutlass::layout::RowMajor; + using ElementAccumulator = ElementAccumulator_; + + using AlphaScaleElementType = typename ScaleTileIterator::Element; + + using ElementCompute = ElementCompute_; + using AccumulatorFragment = Array; + using ComputeFragment = Array; + using OutputVector = Array; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth; + static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1); + + /// Argument structure + struct Arguments + { + + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + Arguments() + : batch_stride_alpha(0) + , batch_stride_C(0) + , batch_stride_D(0) + { + } + + Arguments(typename ElementwiseFunctor::Params elementwise_) + : elementwise(elementwise_) + , batch_stride_alpha(0) + , batch_stride_C(0) + , batch_stride_D(0) + { + } + + Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_, + int64_t batch_stride_C_, int64_t batch_stride_D_) + : elementwise(elementwise_) + , batch_stride_alpha(batch_stride_alpha_) + , batch_stride_C(batch_stride_C_) + , batch_stride_D(batch_stride_D_) + { + } + }; + + struct Params + { + + typename ElementwiseFunctor::Params elementwise; + int64_t batch_stride_alpha; + int64_t batch_stride_C; + int64_t 
batch_stride_D; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Arguments const& args) + : elementwise(args.elementwise) + , batch_stride_alpha(args.batch_stride_alpha) + , batch_stride_C(args.batch_stride_C) + , batch_stride_D(args.batch_stride_D) + { + } + }; + + /// Shared storage + struct SharedStorage + { + }; + +private: + Params const& params_; + SharedStorage& shared_storage_; + MatrixCoord extent_; + MatrixCoord extent_real_; + ElementwiseFunctor elementwise_; + + bool const per_token_quant_; + bool const per_channel_quant_; + + AlphaScaleElementType* ptr_alpha_row_; + AlphaScaleElementType* ptr_alpha_col_; + ScaleTileIterator iterator_alpha_col_; + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + + AlphaScaleElementType element_alpha_row_ = 1.0f; + AlphaScaleElementType element_alpha_col_ = 1.0f; + typename ScaleTileIterator::Fragment fragment_alpha_col_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator beta_; + + int column_offset_; + + MatrixCoord thread_offset_; + +public: + CUTLASS_DEVICE + EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage, + cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx, + typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C, + typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row, + AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C, + typename OutputTileIterator::Element* ptr_D, + cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0, + cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0)) + : params_(params) + , shared_storage_(shared_storage) + , extent_(problem_size) + , elementwise_(params.elementwise) + , per_token_quant_(quant_option.hasPerTokenScaling()) + , per_channel_quant_(quant_option.hasPerChannelScaling()) + , ptr_alpha_row_(ptr_alpha_row) + , ptr_alpha_col_(ptr_alpha_col) + , iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset) + , iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset) + , iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset) + , extent_real_(problem_size_real) + { + beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) + { + iterator_C_.clear_mask(); + } + + if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) + { + element_alpha_col_ = *ptr_alpha_col_; + } + + if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) + { + element_alpha_row_ = *ptr_alpha_row_; + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) + { ///< Total number of split-K slices + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) + { + iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha); + iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); + iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() + { + if (per_channel_quant_) + { + iterator_alpha_col_.load(fragment_alpha_col_); + } + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) + { + fragment_D_.clear(); + fragment_C_.clear(); + + if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) + { + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) + { + // load alpha_row in begin_step only when per token(row) scaling is used + if (per_token_quant_) + { + int thread_offset_row + = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row(); + + arch::global_load( + element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row()); + } + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum) + { + + NumericArrayConverter source_converter; + + ComputeFragment result = source_converter(accum); + if (per_channel_quant_) + { + ComputeFragment alpha_col = reinterpret_cast(&fragment_alpha_col_)[column_idx]; + result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_); + } + else + { + result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_); + } + + // Convert to the output + NumericArrayConverter output_converter; + OutputVector& output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the end of a row + CUTLASS_DEVICE + void end_row(int row_idx) {} + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) + { + + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() {} + +private: + CUTLASS_DEVICE + ComputeFragment per_token_channel_scale_accumulator_( + ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row) + { + + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) + { + result[i] = accum[i] * (scale_col[i] * scale_row); + } + + return result; + } + + CUTLASS_DEVICE + ComputeFragment per_token_scale_accumulator_( + ComputeFragment const& accum, AlphaScaleElementType const& 
scale_col, AlphaScaleElementType const& scale_row) + { + + ComputeFragment result; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ComputeFragment::kElements; ++i) + { + result[i] = accum[i] * (scale_col * scale_row); + } + + return result; + } +}; + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h new file mode 100644 index 00000000000..6f26d790170 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h @@ -0,0 +1,282 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ + original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h + +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/platform/platform.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/epilogue/thread/linear_combination_gelu.h" +#include "cutlass/epilogue/thread/linear_combination_hardswish.h" +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_relu0.h" +#include "cutlass/epilogue/thread/linear_combination_sigmoid.h" + +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/epilogue/thread/reduction_op.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" + +#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" +#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace epilogue +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail +{ + +/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts. +template +struct DefaultIteratorsTensorOp +{ + using WarpTileIterator + = cutlass::epilogue::warp::TileIteratorTensorOpMixed; + + using SharedLoadIterator + = cutlass::epilogue::threadblock::SharedLoadIteratorMixed; + + static int const kFragmentsPerIteration = 2; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load output tile from shared memory in epilogue. 
+/// +/// Satisfies: ReadableTileIterator +/// +template +class SharedLoadIteratorMixed +{ +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = int32_t; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + + static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits::value / 8; + + static int const kThreads = ThreadMap::kThreads; + + /// Fragment object + using Fragment = Array; + + /// Memory access size + using AccessType = AlignedArray; + + /// Vector type used for SMEM loads + using LoadType = AlignedArray::value, ThreadMap::kElementsPerAccess), + const_min(16, kAlignment)>; + + static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements; + +private: + // + // Data members + // + + /// Byte-level pointer + LoadType const* pointers_[kLoadsPerAccess]; + + /// Stride along adjacent rows in units of LoadType + int stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + SharedLoadIteratorMixed(TensorRef ref, int thread_idx) + : stride_((ref.stride(0) / LoadType::kElements)) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx); + + // Initialize pointers + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) + { + pointers_[i] = reinterpret_cast(ref.data()); + + int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess; + int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess; + + col_idx += (bank_offset + i) % kLoadsPerAccess; + + pointers_[i] += thread_offset.row() * stride_ + col_idx; + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) + { + pointers_[i] += pointer_offset / LoadType::kElements; + } + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& offset) + { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kLoadsPerAccess; ++i) + { + pointers_[i] + += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements; + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const + { + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) + { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) + { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) + { + + int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_ + + group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_ + + pointer_offset / LoadType::kElements; + + int frag_row_idx + = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + LoadType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) + { + + int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kLoadsPerAccess; ++v) + { + + int vector_idx + = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * 
kLoadsPerAccess); + + LoadType const* memory_pointer = pointers_[v] + row_ptr_offset; + + frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx]; + } + } + } + } + } + } + + /// Loads a fragment + CUTLASS_DEVICE + void load(Fragment& frag) const + { + + load_with_pointer_offset(frag, 0); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h new file mode 100644 index 00000000000..233d633a823 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h @@ -0,0 +1,141 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * @file epilogue_helpers.h + * + * This file includes types for the epilogues. The empty structs exist so we can signal to template + * code the type of epilogue we want to run, and let the underlying code specify the details such as + * element types, accumulator type and elements per vector access. 
+ * + */ + +#pragma once + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" +#include "cutlass_extensions/epilogue/thread/fused_activations.h" +#include + +namespace tensorrt_llm +{ +namespace cutlass_extensions +{ + +struct EpilogueOpBiasSilu +{ +}; + +struct EpilogueOpBiasReLU +{ +}; + +struct EpilogueOpBiasFtGelu +{ +}; + +struct EpilogueOpBias +{ +}; + +struct EpilogueOpDefaultSilu +{ +}; + +struct EpilogueOpDefaultReLU +{ +}; + +struct EpilogueOpDefaultFtGelu +{ +}; + +struct EpilogueOpDefault +{ +}; + +template +struct Epilogue +{ + static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag"); +}; + +constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationSilu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationRelu; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombinationGeneric; +}; + +template +struct Epilogue +{ + using Op = cutlass::epilogue::thread::LinearCombination; +}; + +} // namespace cutlass_extensions +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl new file mode 100644 index 00000000000..593eca06e3d --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl @@ -0,0 +1,221 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail +{ + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int compute_stage_count_or_override_gated(StageCountAutoCarveout stage_count) +{ + // 32 bytes to account for barriers etc. 
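+    // Illustrative arithmetic (assumed 128x128x64 tile, 16-bit A/B, gate carried on B, i.e. SwapAB == false):
+    // per-stage shared memory is the A tile (128*64*2 = 16384 bytes) plus the doubled gated B tile
+    // (128*64*2*2 = 32768 bytes) plus 32 bytes of barrier storage, i.e. 49184 bytes per stage, so roughly
+    // 230 KB of capacity minus the epilogue carveout supports 4 pipeline stages.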
+ constexpr int stage_barrier_bytes = 32; + constexpr int a_bits = static_cast(sizeof_bits::value); + constexpr int b_bits = static_cast(sizeof_bits::value); + constexpr int stage_bytes = [&]() -> int + { + if constexpr (SwapAB) + { + return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + stage_barrier_bytes; + } + else + { + return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + stage_barrier_bytes; + } + }(); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS +template class Activation, bool SwapAB> +struct CollectiveBuilderGated + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v) &¬ detail:: + is_use_rmem_A()>> +{ + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + static constexpr bool IsArrayOfPointersGemm + = (cute::is_same_v); + static constexpr bool IsFP8Input = detail::is_input_fp8(); + static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm), + "Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n"); + + // For fp32 types, map to tf32 MMA value type + using MmaElementA = cute::conditional_t, tfloat32_t, ElementA>; + using MmaElementB = cute::conditional_t, tfloat32_t, ElementB>; + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = cute::conditional_t + || IsArrayOfPointersGemm, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages + = detail::compute_stage_count_or_override_gated(StageCountType{}); + using DispatchPolicy = cute::conditional_t, + /* For FP8 use a separate mainloop compared to other datatypes */ + cute::conditional_t, + MainloopSm90TmaGmmaWarpSpecialized>>; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMmaGated, + ElementB, TagToStrideB_t, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS +template class Activation, bool SwapAB> +struct CollectiveBuilderGated + || cute::is_same_v + || cute::is_same_v + || cute::is_same_v>> +{ + static_assert(is_static::value); + 
static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert( + detail::is_input_fp8(), "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + static constexpr bool IsArrayOfPointersGemm + = (cute::is_same_v); + using AtomLayoutMNK + = cute::conditional_t + || IsArrayOfPointersGemm, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages + = detail::compute_stage_count_or_override_gated(StageCountType{}); + using DispatchPolicy = cute::conditional_t, + MainloopSm90TmaGmmaWarpSpecialized>; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMmaGated, + ElementB, TagToStrideB_t, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp new file mode 100644 index 00000000000..2f2422c9914 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp" + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template class Activation, + bool SwapAB = false, class Enable = void> +struct CollectiveBuilderGated +{ + static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp new file mode 100644 index 00000000000..d850f36df5f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp @@ -0,0 +1,59 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template class Activation, bool SwapAB = false> +struct CollectiveMmaGated +{ + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp" +#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp new file mode 100644 index 00000000000..dcba6ee6377 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp @@ -0,0 +1,642 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/tensor_predicate.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template class Activation_, bool SwapAB_> +struct CollectiveMmaGated, TileShape_, + ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, + GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_> +{ + static constexpr bool isGated = true; + static constexpr bool SwapAB = SwapAB_; + + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecialized; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using Activation = Activation_; + + using ElementAux = cute::conditional_t; + using ValTypeAux = cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly 
divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutAux = cute::conditional_t; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value + && cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + using InternalElementAux = cute::conditional_t; + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned> smem_Aux; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments + { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params + { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + using TMA_Aux = cute::conditional_t; + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Aux tma_load_aux; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + }; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, void* workspace) + { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto 
problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + + if constexpr (SwapAB) + { + auto ptr_Aux = reinterpret_cast(args.ptr_A + size(make_shape(M, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; + } + else + { + auto ptr_Aux = reinterpret_cast(args.ptr_B + size(make_shape(N, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1}; + } + } + + template + static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) + { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); + + if (!implementable) + { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes + = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast(sizeof_bits::value)) + / 8; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + 
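+        // The B and gated auxiliary descriptors are prefetched as well, so a single thread can warm all
+        // three TMA descriptors before the producer starts issuing loads.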
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const + { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (SwapAB) + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + else + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) + { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) + { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id + = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_aux = SwapAB ? 
mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y) + : mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x); + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + Tensor tAuxgAux = block_tma_aux.partition_S(gAux); + Tensor tAuxsAux = block_tma_aux.partition_D(sAux); + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_aux = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) + { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); + } + } + + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) + { + mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); + } + } + + if constexpr (SwapAB) + { + mcast_mask_aux = mcast_mask_a; + } + else + { + mcast_mask_aux = mcast_mask_b; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); + copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter), + tAuxsAux(_, _, _, write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) + { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) + { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0, + FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors, + Params const& mainloop_params) + { + static_assert(is_rmem::value, "C tensor must be rmem 
resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutAux{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + auto tCsAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.partition_A(sAux); + } + else + { + return thread_mma.partition_B(sAux); + } + }(); + auto tCrAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.make_fragment_A(tCsAux); + } + else + { + return thread_mma.make_fragment_B(tCsAux); + } + }(); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + if constexpr (SwapAB) + { + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE + } + else + { + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE + } + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sAux)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 
1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0); + if constexpr (SwapAB) + { + cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1); + } + else + { + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0); + if constexpr (SwapAB) + { + cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1); + } + else + { + cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum0); + warpgroup_fence_operand(accum1); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) + { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) + { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp new file mode 100644 index 00000000000..72c1adf293f --- /dev/null +++ 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp @@ -0,0 +1,665 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/tensor_predicate.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/collective/fp8_accumulation.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective +{ +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template class Activation_, bool SwapAB_> +struct CollectiveMmaGated, TileShape_, + ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_, + GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_> +{ + static constexpr bool isGated = true; + static constexpr bool SwapAB = SwapAB_; + + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + using Activation = Activation_; + + using ElementAux = cute::conditional_t; + using ValTypeAux = cute::conditional_t; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert( + (size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert( + (size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
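+ // The third mode appended below is the pipeline stage count (DispatchPolicy::Stages); it is
+ // ordered last in the Step so each stage owns one contiguous (BLK_M|BLK_N, BLK_K) slab of
+ // shared memory that the producer/consumer index with a single PIPE coordinate.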
+ using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{})); + using SmemLayoutAux = cute::conditional_t; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value + && cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert( + cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned> smem_Aux; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments + { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params + { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_, _, 0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_, _, 0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + using TMA_Aux = cute::conditional_t; + TMA_A tma_load_a; + TMA_B tma_load_b; + TMA_Aux tma_load_aux; + float scale_d0 = 1.0f; + float scale_d1 = 1.0f; + uint32_t mma_promotion_interval = 4; + }; + + // + // Methods + // + + template + static constexpr Params to_underlying_arguments( + ProblemShape const& problem_shape, Arguments const& args, void* workspace) + { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b 
= make_tma_copy(GmemTiledCopyB{}, tensor_b, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + if constexpr (SwapAB) + { + auto ptr_Aux = reinterpret_cast(args.ptr_A + size(make_shape(M, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux, + SmemLayoutA{}(_, _, cute::Int<0>{}), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval}; + } + else + { + auto ptr_Aux = reinterpret_cast(args.ptr_B + size(make_shape(N, K, L))); + Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB)); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux, + SmemLayoutB{}(_, _, cute::Int<0>{}), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval}; + } + } + + template + static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args) + { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(M, K, L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable + && cutlass::detail::check_alignment(cute::make_shape(N, K, L), StrideB{}); + /* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA + * instructions. */ + implementable = implementable && (args.mma_promotion_interval % 4 == 0); + + if (!implementable) + { + CUTLASS_TRACE_HOST( + " CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytes + = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast(sizeof_bits::value)) + / 8; + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) + { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l) + template + CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const + { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + + if constexpr (SwapAB) + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + else + { + Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l) + Tensor gAux_xkl + = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step{}); // (BLK_N,BLK_K,n,k,l) + return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template + CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, + cute::tuple const& load_inputs, BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) + { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) + { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id + = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y) + : mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
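+ // With SwapAB the auxiliary operand shares A's layout and TMA path and is therefore sliced
+ // with m_coord below; otherwise it shares B's and is sliced with n_coord.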
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k) + Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord); + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + Tensor tAuxgAux = block_tma_aux.partition_S(gAux); + Tensor tAuxsAux = block_tma_aux.partition_D(sAux); + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + uint16_t mcast_mask_aux = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) + { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{})); + } + } + + if constexpr (cute::is_same_v) + { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) + { + mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{})); + } + } + + if constexpr (SwapAB) + { + mcast_mask_aux = mcast_mask_a; + } + else + { + mcast_mask_aux = mcast_mask_b; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter), + tAsA(_, _, _, write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter), + tBsB(_, _, _, write_stage)); + copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter), + tAuxsAux(_, _, _, write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) + { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) + { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0, + FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors, + Params const& mainloop_params) + { + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + 
static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{}); + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + auto tCsAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.partition_A(sAux); + } + else + { + return thread_mma.partition_B(sAux); + } + }(); + auto tCrAux = [&]() -> auto + { + if constexpr (SwapAB) + { + return thread_mma.make_fragment_A(tCsAux); + } + else + { + return thread_mma.make_fragment_B(tCsAux); + } + }(); + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + if constexpr (SwapAB) + { + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux)); // PIPE + } + else + { + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux)); // PIPE + } + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sAux)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + GmmaFP8Accumulation accumulation0(accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + GmmaFP8Accumulation accumulation1(accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA)); + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + if (accumulation0.prepare_if_needed()) + { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + int read_stage = smem_pipe_read.index(); + 
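+ // read_stage selects which pipeline stage's shared-memory buffers the WGMMAs issued below
+ // consume.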
warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0()); + if constexpr (SwapAB) + { + cute::gemm( + tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1()); + } + else + { + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1()); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + accumulation0.promote_if_needed(); + accumulation1.promote_if_needed(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + if (accumulation0.prepare_if_needed()) + { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + } + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) + { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0()); + if constexpr (SwapAB) + { + cute::gemm( + tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1()); + } + else + { + cute::gemm( + tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1()); + } + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + + accumulation0.promote_if_needed(); + accumulation1.promote_if_needed(); + + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + accumulation0.promote_residue_if_needed(); + accumulation1.promote_residue_if_needed(); + + warpgroup_fence_operand(accumulation0()); + warpgroup_fence_operand(accumulation1()); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) + { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) + { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + 
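The mainloop above drives GmmaFP8Accumulation with args.mma_promotion_interval: since FP8 WGMMA's hardware accumulation has limited precision, partial results are periodically promoted into a full-precision accumulator. The scalar CPU sketch below only illustrates that schedule; it is not the CUTLASS implementation, and the function name, parameters, and types are illustrative.

#include <cstddef>
#include <vector>

// "partial" plays the role of the limited-precision WGMMA register accumulator,
// "promoted" the full-precision accumulator (accum0 / accum1 in the mainloop above).
float promoted_dot(std::vector<float> const& a, std::vector<float> const& b, int promotion_interval)
{
    float promoted = 0.0f;
    float partial = 0.0f;
    for (std::size_t k = 0; k < a.size(); ++k)
    {
        partial += a[k] * b[k];            // one MMA step along K
        if ((k + 1) % promotion_interval == 0)
        {
            promoted += partial;           // promote_if_needed()
            partial = 0.0f;                // next MMA group restarts with ScaleOut::Zero
        }
    }
    return promoted + partial;             // promote_residue_if_needed()
}

A larger interval amortizes the promotion cost but lets more error accumulate in the limited-precision partial sum, which is why can_implement() above constrains mma_promotion_interval.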
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h new file mode 100644 index 00000000000..2edd5a228b4 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h @@ -0,0 +1,438 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and + batched array variants. +*/ + +#pragma once + +// #include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace device +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/* + This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088) + It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs + and SmoothQuant. 
The newer device layer is not compatible with these older kernel level APIs. + + Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support + that feature at the moment. + */ + +template +class GemmUniversalBaseCompat +{ +public: + using GemmKernel = GemmKernel_; + using ThreadblockShape = typename GemmKernel::Mma::Shape; + + using ElementA = typename GemmKernel::ElementA; + using LayoutA = typename GemmKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = GemmKernel::kTransformA; + + using ElementB = typename GemmKernel::ElementB; + using LayoutB = typename GemmKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = GemmKernel::kTransformB; + + using ElementC = typename GemmKernel::ElementC; + using LayoutC = typename GemmKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + + using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle; + using Operator = typename GemmKernel::Operator; + + /// Argument structure + using Arguments = typename GemmKernel::Arguments; + +protected: + /// Kernel parameters object + typename GemmKernel::Params params_; + +protected: + /// Private helper to obtain the grid dimensions with fix-up for split-K + static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args) + { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + gemm_k_size = args.problem_size.k(); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + int const kAlignK + = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) + { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + } + +public: + /// Constructs the GEMM. + GemmUniversalBaseCompat() {} + + /// Determines whether the GEMM can execute the given problem. 
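+ /// In addition to the kernel-level check, this rejects problems whose swizzled launch grid
+ /// would exceed the 65535 (kGridYZMax) limit on gridDim.y / gridDim.z.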
+ static Status can_implement(Arguments const& args) + { + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + ThreadblockSwizzle threadblock_swizzle; + dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); + + if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) + { + + return Status::kErrorInvalidProblem; + } + + return GemmKernel::can_implement(args); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()"); + + size_t workspace_bytes = 0; + + // Determine grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + if (args.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + // Split-K parallel always requires a temporary workspace + workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k()); + } + else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) + { + + // Serial split-K only requires a temporary workspace if the number of partitions along the + // GEMM K dimension is greater than one. + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape); + + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()"); + + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape); + + CUTLASS_TRACE_HOST(" grid_tiled_shape: " << grid_tiled_shape << "\n" + << " result = {" << result << "}"); + + return result; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()"); + + int max_active_blocks = -1; + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + if (smem_size <= (48 << 10)) + { + + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, smem_size); + + if (result == cudaSuccess) + { + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + } + else + { + + // Query assuming zero shared memory then compute occupancy limit based on SMEM + cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, GemmKernel::kThreadCount, 0); + + if (result != cudaSuccess) + { + + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + + return -1; + } + + if (smem_capacity < 0) + { + int device_idx = 0; + result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) + { + return -1; + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) + { + return -1; + } + + 
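+ // No shared memory capacity was supplied by the caller; fall back to the device's
+ // reported shared memory per multiprocessor.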
smem_capacity = static_cast(properties.sharedMemPerMultiprocessor); + } + + int occupancy = std::min(max_active_blocks, smem_capacity / smem_size); + + CUTLASS_TRACE_HOST(" occupancy: " << occupancy); + + return occupancy; + } + + CUTLASS_TRACE_HOST(" returning internal error"); + + return -1; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + size_t workspace_bytes = get_workspace_size(args); + + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + if (workspace_bytes) + { + + if (!workspace) + { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + + return Status::kErrorWorkspaceNull; + } + + if (args.mode == GemmUniversalMode::kGemm) + { + CUTLASS_TRACE_HOST(" clearing device workspace"); + cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream); + + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" cudaMemsetAsync() returned error " << cudaGetErrorString(result)); + + return Status::kErrorInternal; + } + } + } + + // Get CUDA grid shape + cutlass::gemm::GemmCoord grid_tiled_shape; + int gemm_k_size = 0; + + get_grid_shape_(grid_tiled_shape, gemm_k_size, args); + + // Initialize the Params structure + params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast(workspace)); + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) + { + cudaError_t result + = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) + { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) + { + + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) + { + return Status::kErrorWorkspaceNull; + } + + params_.update(args, workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) + { + CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()"); + + // + // Configure grid and block dimensions + // + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + // + // Launch kernel + // + + CUTLASS_TRACE_HOST(" grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes"); + + // Launch + cutlass::Kernel<<>>(params_); + + // + // Query for errors + // + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) + { + return run(stream); + } + + /// Runs the kernel using initialized state. 
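+ /// Convenience overload: calls initialize() with the given arguments and workspace,
+ /// then run() on the given stream.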
+ Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) + { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) + { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h new file mode 100644 index 00000000000..bfd3666b9c1 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h @@ -0,0 +1,542 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
+ \file + \brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/arch/arch.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_universal.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" + +#include "cutlass/trace.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace device +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk, + int64_t* splitk_buffer_offsets) +{ + // in_tensor: [problem_idx, k_partition, hidden_size] + // Note that different requests of in_tensor might have different hidden_size (=m*n) + // so, we need to use splitk_buffer_offsets. + // out_tensor: problem_idx * [hidden_size] + + int const problem_idx = blockIdx.y; + GemmCoord problem = problem_sizes[problem_idx]; + int const hidden_size = problem.m() * problem.n(); + const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk; + T_OUT* out_tensor_ = out_tensor[problem_idx]; + + for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x) + { + float sum = 0.0f; + for (int k_idx = 0; k_idx < splitk; k_idx++) + { + sum += (float) in_tensor_[k_idx * hidden_size + i]; + } + out_tensor_[i] = (T_OUT) (sum); + } +} + +/// GEMM Grouped +template +class BaseSplitkGrouped +{ +public: + using BaseKernel = BaseKernel_; + + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + /// Argument structure + using Arguments = typename BaseKernel::Arguments; + + using 
ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo; + +protected: + /// Kernel parameters object + typename BaseKernel::Params gemm_params_; + +private: + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count) + { + int32_t tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) + { + cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i]; + BaseKernel::ProblemVisitor::possibly_transpose_problem(problem); + tiles += problem_tile_count(problem); + } + return tiles; + } + + /// Copy from `data` to `workspace` + Status copy_to_workspace(void* workspace, void* data, size_t bytes) + { + cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice); + if (cuda_error != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + cuda_error = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error)); + return Status::kErrorInternal; + } + + return Status::kSuccess; + } + + /// Precomputes scheduling information for the grouped GEMM + Status precompute(Arguments const& args, int32_t tile_count, void* workspace) + { + size_t workspace_bytes = get_workspace_size(args); + std::vector host_workspace(workspace_bytes); + BaseKernel::ProblemVisitor::host_precompute( + args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data()); + return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes); + } + + /// Reorder `data` according to `indices` + template + static void reorder_array(T* data, std::vector const& indices) + { + // For now, simply create a copy of the data and then copy over to the original. + std::vector copy(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) + { + copy.at(i) = data[indices[i]]; + } + + memcpy(data, copy.data(), indices.size() * sizeof(T)); + } + +public: + /// Constructs the GEMM. + BaseSplitkGrouped() {} + + /// Determines whether the GEMM can execute the given problem. 
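+ /// Forwards directly to the kernel-level BaseKernel::can_implement() check.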
+ static Status can_implement(Arguments const& args) + { + + return BaseKernel::can_implement(args); + } + + /// Get the number of tiles in a problem + static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem) + { + auto grid = BaseKernel::ProblemVisitor::grid_shape(problem); + return BaseKernel::ProblemVisitor::tile_count(grid); + } + + /// Get the number of tiles across all problems in a group + static int32_t group_tile_count(Arguments const& args) + { + if (args.host_problem_sizes == nullptr) + { + CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes"); + return -1; + } + + return group_tile_count(args.host_problem_sizes, args.problem_count); + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const& args) + { + size_t total_mn = 0; + for (int i = 0; i < args.problem_count; i++) + { + total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n(); + } + size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices; + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) + { + workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size( + args.host_problem_sizes, args.problem_count, args.threadblock_count); + } + return workSpaceSize; + } + + /// Computes the grid shape + static dim3 get_grid_shape(Arguments const& args) + { + + return dim3(args.threadblock_count, 1, 1); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int smem_capacity = -1) + { + + CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()"); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) + { + result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, Kernel, BaseKernel::kThreadCount, smem_size); + + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Sorts each pointer passed in according to the indices that sort + /// `problem_sizes_ptr` in descending order of problem-K dimension. 
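+    /// Sorting by descending K schedules the problems with the longest mainloops first, which tends
+    /// to balance work across the persistent threadblocks. All per-problem arrays (leading
+    /// dimensions and the A/B/C/D offsets) are permuted with the same index order so they remain
+    /// consistent with the reordered problem sizes.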
+ static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr, + int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr, + int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr) + { + std::vector indices(problem_count); + std::iota(indices.begin(), indices.end(), 0); + std::stable_sort(indices.begin(), indices.end(), + [&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); }); + + reorder_array(problem_sizes_ptr, indices); + reorder_array(lda_host_ptr, indices); + reorder_array(ldb_host_ptr, indices); + reorder_array(ldc_host_ptr, indices); + reorder_array(ldd_host_ptr, indices); + reorder_array(offset_A_ptr, indices); + reorder_array(offset_B_ptr, indices); + reorder_array(offset_C_ptr, indices); + reorder_array(offset_D_ptr, indices); + } + + /// Computes the number of threadblocks to launch for the grouped kernel + static int sufficient( + cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1) + { + // Determine the number of blocks that would be launched to fill up a single + // wave on the GPU with each SM having maximum occupancy. + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result)); + return 0; + } + + int multiprocessor_count; + result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx); + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result)); + return 0; + } + + bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count); + if (override_sm_count) + { + available_sm_count = multiprocessor_count; + } + + int max_active_blocks = maximum_active_blocks(); + if (max_active_blocks <= 0) + { + return 0; + } + + int occupancy_based_block_count = available_sm_count * max_active_blocks; + + if (problem_sizes_ptr == nullptr || problem_count == 0) + { + return occupancy_based_block_count; + } + + int total_tiles = group_tile_count(problem_sizes_ptr, problem_count); + + // If the group contains a single problem, launching the exact number of + // threadblocks needed to cover the problem minimizes the work performed + // per threadblock in finding the next tile to compute. We return total_tiles + // unless the user has provided the SM count. + if (problem_count == 1 && override_sm_count) + { + return total_tiles; + } + + // Choose between the full wave of threadblocks and the tile count. If there + // are fewer tiles in the group than threadblocks in the full wave, only + // some threadblocks will be assigned tiles. Those threadblocks + // which are not assigned tiles still need to perform the work of iterating through + // problem sizes to determine that they have no work to do. This competes for cycles + // with those threadblocks that are assigned tiles to compute. + return std::min(total_tiles, occupancy_based_block_count); + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) + { + + CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + // Workspace + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) + { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) + { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) + { + return status; + } + + gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count); + } + else + { + gemm_params_ = typename BaseKernel::Params(args, workspace); + } + + // Specify shared memory capacity for kernel. + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) + { + cudaError_t result + = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) + { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const& args, void* workspace = nullptr) + { + + size_t workspace_bytes = get_workspace_size(args); + + if (workspace_bytes && !workspace) + { + return Status::kErrorWorkspaceNull; + } + + if (BaseKernel::ProblemVisitor::kRequiresPrecomputation) + { + int32_t tile_count = group_tile_count(args); + Status status = precompute(args, tile_count, workspace); + if (status != Status::kSuccess) + { + return status; + } + + gemm_params_.update(args, workspace, tile_count); + } + else + { + gemm_params_.update(args, workspace); + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) + { + if (!gemm_params_.problem_visitor.problem_count) + { + return Status::kSuccess; + } + + // + // Launch kernel + // + + // Launch splitk grouped gemm + { + dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices); + dim3 block(BaseKernel::kThreadCount, 1, 1); + + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + cutlass::Kernel<<>>(gemm_params_); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + // Launch splitkReduction + { + dim3 grid(32, gemm_params_.problem_visitor.problem_count); + dim3 block(256); + splitkReduction<<>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split, + gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices, + gemm_params_.splitk_buffer_offsets); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) + { + CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) + { + return run(stream); + } + + /// Initializes and runs the kernel. 
+ Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr) + { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) + { + status = run(stream); + } + + return status; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GEMM Grouped +template +class SplitkGemmGrouped : public BaseSplitkGrouped +{ +public: + using GemmKernel = GemmKernel_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h new file mode 100644 index 00000000000..100a1161a88 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h @@ -0,0 +1,162 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/half.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +template +struct MixedGemmArchTraits +{ + static_assert(dependent_false, "Unrecognised parameterization"); +}; + +template +struct MixedGemmArchTraits +{ + static constexpr int Stages = 2; + using OperatorClass = cutlass::arch::OpClassSimt; + using AccType = float; + using LayoutB = cutlass::layout::ColumnMajor; + + static constexpr int ElementsPerAccessA = 1; + static constexpr int ElementsPerAccessB = 1; + static constexpr int ElementsPerAccessC = 1; + static constexpr int ThreadblockK = 8; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// ======================= Turing Traits ============================== +// Note that turing does not have native bfloat support so weights and activations will be casted to fp16 +// and compute will happen in fp16 then will be converted for bf16 output. 
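+// In that fallback the tensor-core MMA consumes fp16 operands and accumulates in fp32 (AccType =
+// float below); only the final store converts back to bf16, so results may differ slightly from
+// architectures with native bf16 support.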
+template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ampere Traits ============================== +template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Operator = typename LayoutDetails::Operator; +}; + +// ======================= Ada Traits ============================== +template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; +}; + +// FP8 A/B = fp8, C/D = fp32 +template +struct MixedGemmArchTraits::value + || cutlass::platform::is_same::value>::type> +{ +private: + using LayoutDetails = LayoutDetailsB; + +public: + static constexpr int ThreadblockK = LayoutDetails::ThreadblockK; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using AccType = float; + // be careful, TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t + using TypeC = __nv_bfloat16; + using LayoutB = typename LayoutDetails::Layout; + + static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess; + static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits::value; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits::value>; + + using Operator = typename LayoutDetails::Operator; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h new file mode 
100644 index 00000000000..3fd722994e2 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h @@ -0,0 +1,57 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +template +struct Int8GemmArchTraits +{ + using OperatorClass = cutlass::arch::OpClassSimt; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; +}; + +// ======================= Turing Traits ============================== +template <> +struct Int8GemmArchTraits +{ + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>; +}; + +// ======================= Ampere Traits ============================== +template <> +struct Int8GemmArchTraits +{ + using OperatorClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h new file mode 100644 index 00000000000..1dbd0b1765f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_splitk_gemm_grouped.h @@ -0,0 +1,207 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" + +#include "cutlass/layout/permute.h" + +#include "splitk_gemm_grouped.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Complex elementwise transformation on A operand + ComplexTransform TransformA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Complex elementwise transformation on B operand + ComplexTransform TransformB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly, + /// Operation performed by GEMM + typename Operator = typename device::DefaultGemmConfiguration::Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Permute result D + 
typename PermuteDLayout = layout::NoPermute, + /// + typename Enable = void> +struct DefaultSplitkGemmGrouped; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Real-valued GEMM kernels +// + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Layout type for C and D matrix operands + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Whether the schedule of problems to visit has been precomputed + GroupScheduleMode GroupScheduleMode_, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Permute result D + typename PermuteDLayout> +struct DefaultSplitkGemmGrouped::value>::type> +{ + + // If true, we must construct a 'transposed-and-exchanged' Mma operator. + static bool const kInternalTranspose = platform::is_same::value; + + using MapArguments = kernel::detail::MapArguments; + + // Define the default GEMM kernel + using DefaultGemmKernel = typename kernel::DefaultGemm::GemmKernel; + + /// Define the kernel in terms of the default kernel + using GemmKernel = kernel::SplitkGemmGrouped; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h new file mode 100644 index 00000000000..0baec58ea9a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fpA_intB_gemm.h @@ -0,0 +1,566 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail +{ +template +inline constexpr bool dependent_false_v = false; +} + +template +struct GemmFpAIntB +{ + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Element; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Mma::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformA; + + // Type definitions about the mainloop. 
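+    // These aliases re-export the tile shapes, alignments and stage count of the underlying Mma so
+    // that host-side dispatch and launch code can introspect the instantiated kernel.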
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + + /// Parameters structure + struct Arguments + { + GemmUniversalMode mode = GemmUniversalMode::kGemm; + + cutlass::gemm::GemmCoord problem_size; + int group_size; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + + // Control serial split-k + int batch_count; + + typename EpilogueOutputOp::Params output_op; + + // For gather+scatter operations + int const* gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // Included so we can use Gemm Universal + int batch_stride_D = 0; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Arguments() {} + + CUTLASS_HOST_DEVICE + Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size, + typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, int serial_split_k_factor, + typename EpilogueOutputOp::Params output_op = typename EpilogueOutputOp::Params(), + int const* gather_A_indices = nullptr, int const* gather_B_indices = nullptr, + int const* scatter_D_indices = nullptr) + : problem_size(problem_size) + , group_size(group_size) + , ref_A(ref_A) + , ref_B(ref_B) + , ref_scale(ref_scale) + , ref_zero(ref_zero) + , ref_C(ref_C) + , ref_D(ref_D) + , batch_count(serial_split_k_factor) + , output_op(output_op) + , gather_A_indices(gather_A_indices) + , gather_B_indices(gather_B_indices) + , scatter_D_indices(scatter_D_indices) + { + } + }; + + /// Parameters structure + struct Params + { + cutlass::gemm::GemmCoord problem_size; + int group_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorScale::Params params_scale; + typename Mma::IteratorScale::TensorRef ref_scale; + typename Mma::IteratorScale::TensorRef ref_zero; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename EpilogueOutputOp::Params output_op; + int* semaphore; + int gemm_k_size; + // For gather+scatter operations + int const* 
gather_A_indices; + int const* gather_B_indices; + int const* scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0) + , semaphore(0) + , gemm_k_size(0) + { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size, + void* workspace = nullptr) + : problem_size(args.problem_size) + , group_size(args.group_size) + , grid_tiled_shape(grid_tiled_shape) + , swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)) + , params_A(args.ref_A.layout()) + , ref_A(args.ref_A) + , params_B(args.ref_B.layout()) + , ref_B(args.ref_B) + , params_scale(args.ref_scale.layout()) + , ref_scale(args.ref_scale) + , ref_zero(args.ref_zero) + , params_C(args.ref_C.layout()) + , ref_C(args.ref_C) + , params_D(args.ref_D.layout()) + , ref_D(args.ref_D) + , output_op(args.output_op) + , semaphore(static_cast(workspace)) + , gemm_k_size(gemm_k_size) + , gather_A_indices(args.gather_A_indices) + , gather_B_indices(args.gather_B_indices) + , scatter_D_indices(args.scatter_D_indices) + { + } + }; + + /// Shared memory storage structure + union SharedStorage + { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + GemmFpAIntB() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(Arguments const& args) + { + static int const kAlignmentA + = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB + = (platform::is_same>::value) ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + + static int const kAlignmentScale = Mma::IteratorScale::AccessType::kElements; + + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(args.ref_A, kAlignmentA)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_B, kAlignmentB)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_scale, kAlignmentScale)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_zero, kAlignmentScale)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_C, kAlignmentC)) + { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(args.ref_D, kAlignmentC)) + { + return Status::kErrorMisalignedOperand; + } + + if (!args.ref_scale.good()) + { + return Status::kErrorNotSupported; + } + + if constexpr (hasZero(Mma::QuantOp)) + { + if (!args.ref_zero.good()) + { + return Status::kErrorNotSupported; + } + } + else + { + if (args.ref_zero.good()) + { + return Status::kErrorNotSupported; + } + } + + if constexpr (isFinegrained(Mma::QuantOp)) + { + if (args.group_size != 64 && args.group_size != 128) + { + return Status::kErrorNotSupported; + } + } + + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + + return 0; + } + + // Initializes the fine grained scale+bias iterator. 
Needed since the fine grained iterator + // has a different constructor signature than a regular cutlass iterator + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) + { + + return IteratorScale(params, pointer_scale, pointer_zero, extent, thread_id, threadblock_offset, group_size); + } + + template = true> + CUTLASS_DEVICE static IteratorScale initialize_scale(typename IteratorScale::Params const& params, + typename IteratorScale::Pointer pointer_scale, typename IteratorScale::Pointer pointer_zero, + typename IteratorScale::TensorCoord extent, int thread_id, + typename IteratorScale::TensorCoord const& threadblock_offset, int group_size) + { + + return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset); + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) + { + using LayoutB = typename Mma::IteratorB::Layout; + static_assert(platform::is_same::value && kInterleave == 1 + || platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() + || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) + { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave, + threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave}; + + typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64; + typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0; + cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices); + + typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(), + {problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B, + params.gather_B_indices); + + typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? 
problem_size_k / 64 : 1; + typename Mma::IteratorScale iterator_scale = initialize_scale( + params.params_scale, params.ref_scale.data(), params.ref_zero.data(), + {scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) + { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) + { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(), + thread_idx, threadblock_offset, params.scatter_D_indices); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) + { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) + { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) + { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) + { + + // The final threadblock resets the semaphore for subsequent grids. 
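+                // Writing 0 returns the semaphore to its initial state, so the k == 0 partition of a
+                // later launch that reuses this workspace can acquire it without waiting.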
+ lock = 0; + } + else + { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) + { + if constexpr (platform::is_same::value) + { + run_kernel_(params, shared_storage); + } + else + { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ == 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels. +#else + static_assert( + false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh new file mode 100644 index 00000000000..1bd0a3f11a8 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel.cuh @@ -0,0 +1,218 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include + +namespace fused_moe +{ +template +struct Fused_Moe_Kernel_sm80 +{ + static constexpr int kMaxTileM = MaxTileM_; + static constexpr int kTileN = isGateActivation(activation_type_) ? 
TileN_ / 2 : TileN_; + static constexpr int kTileK = TileK_; + static constexpr int kStages = Stages_; + static constexpr Activation_Type activation_type = activation_type_; + + using ElementInput = ElementInput_; + using ElementWeight = ElementWeight_; + using ElementOutput = ElementOutput_; + using BaseKernelTraits = Fused_Moe_Kernel_traits_sm80; + using Routine_Arguments = Routine_Arguments; + using Routine_Params = Routine_Params; + using ProblemVisitor + = cutlass::gemm::kernel::MoeProblemVisitor, false>, + cutlass::gemm::GemmShape, cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + BaseKernelTraits::kThreadCount, BaseKernelTraits::kThreadCount>; + + struct Arguments + { + Routine_Arguments routine_args; + int problem_count{}; + int threadblock_count{}; + }; + + struct Params + { + Routine_Params routine_params; + int threadblock_count{}; + typename ProblemVisitor::Params problem_visitor_param; + }; + + using BaseKernelTraits_m16 = Fused_Moe_Kernel_traits_sm80; + static constexpr bool use_m16 = TileK_ >= 64; // use tileshape m = 16 when original tileshape k >= 64 + + static constexpr int kSmemSize = use_m16 + ? (BaseKernelTraits::kSmemSize > BaseKernelTraits_m16::kSmemSize ? BaseKernelTraits::kSmemSize + : BaseKernelTraits_m16::kSmemSize) + : BaseKernelTraits::kSmemSize; + static constexpr int kThreadCount = BaseKernelTraits::kThreadCount; + + static constexpr bool can_implement(int const avaliable_smem_size) + { + return BaseKernelTraits::can_implement(avaliable_smem_size); + } + + static Params to_underlying_arguments(Arguments const& args) + { + return { + {args.routine_args.ptr_input, args.routine_args.ptr_fc1, args.routine_args.ptr_bias, + args.routine_args.ptr_output, args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, + args.routine_args.gemm_k, args.routine_args.num_expert, args.routine_args.bias_is_broadcast}, + args.threadblock_count, + {args.routine_args.total_tokens_including_expert, args.routine_args.gemm_n, args.routine_args.gemm_k, + args.problem_count, nullptr, 0}}; + } + + CUTE_DEVICE + void run_device(Params const& params) + { +#define ROUTINE_PATH(kTileM_size) \ + { \ + constexpr int kTileM = use_m16 ? (kTileM_size) : ((kTileM_size) == 16 ? 
32 : (kTileM_size)); \ + using RoutineTraits = Fused_Moe_Kernel_routine_sm80; \ + RoutineTraits routine{}; \ + int const block_m_idx = (block_m_idx_temp) *kMaxTileM / kTileM; \ + routine.run_routine(params.routine_params, problem_index, block_m_idx, block_n_idx, gemm_m); \ + } + typename ProblemVisitor::SharedStorage dummy_storage{}; + ProblemVisitor problem_visitor(params.problem_visitor_param, dummy_storage, blockIdx.x); + while (problem_visitor.next_tile()) + { + auto problem_size = problem_visitor.problem_size(); + auto grid_size = problem_visitor.grid_shape(problem_size); + auto problem_index = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + int const gemm_m = problem_size.m(); + const int32_t block_m_idx_temp = cta_idx / grid_size.n(); + const int32_t block_n_idx = cta_idx % grid_size.n(); + + int const residue_m = gemm_m - kMaxTileM * block_m_idx_temp; + if (residue_m > kMaxTileM / 2) + { + using RoutineTraits = Fused_Moe_Kernel_routine_sm80; + RoutineTraits routine{}; + routine.run_routine(params.routine_params, problem_index, block_m_idx_temp, block_n_idx, gemm_m); + } + else + { + + if constexpr (kMaxTileM >= 128) + { + if (residue_m > 32) + { + ROUTINE_PATH(64); + } + else if (residue_m > 16) + { + ROUTINE_PATH(32); + } + else + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + } + else if (kMaxTileM == 64) + { + if (residue_m > 16) + { + ROUTINE_PATH(32); + } + else + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + } + else if (kMaxTileM == 32) + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + else + { + // TODO: use cuda core gemm here + ROUTINE_PATH(16); + } + } + problem_visitor.advance(gridDim.x); + } +#undef ROUTINE_PATH + } +}; + +template +__global__ void run_global(__grid_constant__ typename GemmType::Params const params) +{ + GemmType gemm; + gemm.run_device(params); +} + +/// Computes the maximum number of active blocks per multiprocessor +template +static int fused_gemm_maximum_active_blocks(int smem_capacity = -1) +{ + + CUTLASS_TRACE_HOST("BaseGrouped::maximum_active_blocks()"); + + constexpr int smem_size = GemmType::kSmemSize; + + CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes"); + + cudaError_t result; + if (smem_size > (48 << 10)) + { + result = cudaFuncSetAttribute(run_global, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result)); + return -1; + } + } + + int max_active_blocks = -1; + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, run_global, GemmType::kThreadCount, smem_size); + + if (result != cudaSuccess) + { + // Call cudaGetLastError() to clear the error bit + result = cudaGetLastError(); + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; +} +} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh new file mode 100644 index 00000000000..4c46a541efd --- /dev/null +++ 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_routine.cuh @@ -0,0 +1,799 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace fused_moe +{ + +template +struct Fused_Moe_Kernel_routine_sm80; + +template +struct Fused_Moe_Kernel_routine_sm80> +{ + using KT = Fused_Moe_Kernel_traits_sm80; + using Params = Routine_Params; + + CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params) + { + using X = cute::Underscore; + + int const M = gemm_m; + int const N1 = params.gemm_n; + int const K1 = params.gemm_k; + bool const bias_is_broadcast = params.bias_is_broadcast; + + int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]); + typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1; + typename KT::ElementWeight const* ptr_fc1_gate_ + = params.ptr_fc1 + (2 * problem_index + 1) * N1 * K1; // TODO: we only focus on gated activation.. + typename KT::ElementWeight const* ptr_fc1_ + = params.ptr_fc1 + 2 * problem_index * N1 * K1; // TODO: we only focus on gated activation.. + typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr) + ? nullptr + : (bias_is_broadcast ? params.ptr_bias + 2 * problem_index * N1 : params.ptr_bias + 2 * row_jump * N1); + typename KT::ElementInput const* ptr_bias_gate_ = (params.ptr_bias == nullptr) + ? nullptr + : (bias_is_broadcast ? params.ptr_bias + (2 * problem_index + 1) * N1 + : params.ptr_bias + (2 * row_jump + 1) * N1); + typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1; + + cute::Tensor mInput_mk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_input_)), + cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mfc1_gate_nk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_gate_)), + cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mfc1_nk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_)), + cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mBias_mn = cute::make_tensor( + cute::make_gmem_ptr(static_cast(ptr_bias_)), cute::make_shape(M, N1), + cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2, + cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. + + cute::Tensor mBias_gate_mn = cute::make_tensor( + cute::make_gmem_ptr(static_cast(ptr_bias_gate_)), cute::make_shape(M, N1), + cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1 * 2, + cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. 
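+        // Packing assumed above (see the gated-activation TODO): the fc1 weights of all experts are
+        // laid out as [2 * num_expert, N1, K1], with the plain weight of problem_index e at block 2*e
+        // and its gate weight at block 2*e + 1. For example, for problem_index == 1 the pointers
+        // computed above are
+        //   ptr_fc1_      = params.ptr_fc1 + 2 * N1 * K1
+        //   ptr_fc1_gate_ = params.ptr_fc1 + 3 * N1 * K1
+        // Biases follow the same even/odd pairing, and use a stride of 0 in M when bias_is_broadcast
+        // is set (the [1, N] bias is viewed as [M, N]).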
+ + cute::Tensor mOutput_mn + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_output_)), + cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{})); + + cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_K, m, k) + cute::Tensor gfc1_gate_nk = cute::local_tile(mfc1_gate_nk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) + cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) + + cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + cute::Tensor gBias_gate_mn = cute::local_tile(mBias_gate_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + return cute::make_tuple(gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn); + } + + // be careful, m_idx will change when use another tile shape.. + CUTE_DEVICE void run_routine( + Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m) + { + extern __shared__ char smem_[]; + typename KT::SharedStorage& shared_storage = *reinterpret_cast(smem_); + int const thread_idx = threadIdx.x; + bool const bias_is_broadcast = params.bias_is_broadcast; + // gmem tensor partition .. + auto [gInput_mk, gfc1_gate_nk, gfc1_nk, gBias_mn, gBias_gate_mn, gOutput_mn] + = gmem_tensor_init(problem_index, gemm_m, params); + int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk); + auto const n_tile_count = cute::size<2>(gfc1_gate_nk); + + // smem tensor .. 
+ cute::Tensor sInput = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage) + cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()), + typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) + cute::Tensor sfc1_gate_weight + = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_gate_weight.data()), + typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) + cute::Tensor sO = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N) + + // (1) first step, get the fc1_res and fc1_gate + + // (1.1) get partition for gmem -> smem + cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k) + cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) + cute::Tensor gfc1g = gfc1_gate_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) + + typename KT::GmemTiledCopyA gmem_tiled_copy_A; + typename KT::GmemTiledCopyB gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k) + cute::Tensor tInputsInput = gmem_thr_copy_A.partition_D(sInput); // (ACPY,ACPY_M,ACPY_K,Stage) + cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k) + cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage) + cute::Tensor tfc1ggfc1g = gmem_thr_copy_B.partition_S(gfc1g); // (BCPY,BCPY_N,BCPY_K,k) + cute::Tensor tfc1gsfc1g = gmem_thr_copy_B.partition_D(sfc1_gate_weight); // (BCPY,BCPY_N,BCPY_K,Stage) + + // Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor) + cute::Tensor tInputpInput + = cute::make_tensor(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)), + cute::Stride{}); + // Construct identity layout for sInput + cute::Tensor cInput = make_identity_tensor( + make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + + // Repeat the partitioning with identity layouts + cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<0>(tInputpInput); ++m) + { + tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m + } + + // (1.2) prefetch gmem -> smem + cute::clear(tInputsInput); // we don't need to clear tfc1sfc1.. 
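+        // Only the activation tile is predicated: rows at or beyond residue_m are masked out by
+        // tInputpInput, and the clear() above zero-fills their smem slots so masked rows read as
+        // zeros rather than stale shared memory. The weight tiles are copied unpredicated, which
+        // appears to assume gemm_n and gemm_k are multiples of the tile shape.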
+        auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // the iterator starts at tile 0
+        int k_tile_count = cute::size<2>(gInput);
+        CUTLASS_PRAGMA_UNROLL
+        for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
+        {
+            if (k_tile_count <= 0)
+            {
+                cute::clear(tInputpInput);
+            }
+            // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
+            //    tInputsInput(cute::_, cute::_, cute::_, k_pipe));
+            // use copy_if so rows beyond residue_m are not loaded
+            cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
+                tInputsInput(cute::_, cute::_, cute::_, k_pipe));
+            cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
+                tfc1sfc1(cute::_, cute::_, cute::_, k_pipe));
+            cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter),
+                tfc1gsfc1g(cute::_, cute::_, cute::_, k_pipe));
+            cute::cp_async_fence();
+            k_tile_count--;
+            if (k_tile_count > 0)
+            {
+                ++k_tile_iter;
+            }
+        }
+
+        // (1.3) get partition for rf
+        typename KT::TiledMma tiled_mma;
+        auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
+        cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0));          // (MMA,MMA_M,MMA_K)
+        cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0));       // (MMA,MMA_N,MMA_K)
+        cute::Tensor tOrfc1g = thr_mma.partition_fragment_B(sfc1_gate_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
+
+        cute::Tensor accum
+            = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
+        cute::Tensor accum_gate
+            = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
+        cute::clear(accum);
+        cute::clear(accum_gate);
+        // check that the fragment shapes are consistent
+        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum));      // MMA_M
+        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum_gate)); // MMA_M
+        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum));        // MMA_N
+        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum_gate));   // MMA_N
+        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum));       // MMA_N
+        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1g) == cute::size<2>(accum_gate));  // MMA_N
+        CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1));     // MMA_K
+        CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1g));    // MMA_K
+        CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
+        CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));
+
+        // (1.4) retile the smem and rf tensors for copy
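// The loop above pre-issues Stages - 1 asynchronous tile loads before any math; the mainloop
// below then computes on the oldest resident stage while refilling the newest one, tracked by
// smem_pipe_read / smem_pipe_write. A minimal sketch of that index bookkeeping only (the
// cp.async copies and MMAs are elided; 'Stages' mirrors KT::Stages):
template <int Stages>
__device__ void k_loop_pipeline_sketch(int k_tile_count)
{
    int read = 0;           // stage currently consumed by the MMAs
    int write = Stages - 1; // stage the next gmem -> smem copy lands in
    for (; k_tile_count > 0; --k_tile_count)
    {
        // cp_async_wait + __syncthreads(): stage 'read' is now resident in smem
        // ... issue MMAs on stage 'read', start the next tile load into stage 'write' ...
        write = read;
        read = (read + 1 == Stages) ? 0 : read + 1;
    }
}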
+ auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage) + cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K + + auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage) + cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K) + cute::Tensor tOsfc1g = smem_thr_copy_B.partition_S(sfc1_gate_weight); // (CPY,CPY_N,CPY_K,Stage) + cute::Tensor tOrfc1g_copy_view = smem_thr_copy_B.retile_D(tOrfc1g); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1g) == cute::size<1>(tOrfc1g_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1g) == cute::size<2>(tOrfc1g_copy_view)); // CPY_K + + // (1.5) mainloop + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = KT::Stages - 1; + + cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + cute::Tensor tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); + + constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput); + // prefetch register pipeline + if constexpr (K_BLOCK_MAX > 1) + { + cute::cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}), + tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{})); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}), + tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{})); + cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, cute::Int<0>{}), + tOrfc1g_copy_view(cute::_, cute::_, cute::Int<0>{})); + } + // k loop for mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + } + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), + tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + // cute::copy(gmem_tiled_copy_A, 
tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + // tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy_if(gmem_tiled_copy_A, tInputpInput, + tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy(gmem_tiled_copy_B, tfc1ggfc1g(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1gsfc1g(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::cp_async_fence(); + if (k_tile_count - 1 > 0) + { + ++k_tile_iter; + } + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), + accum); + cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), + tOrfc1g(cute::_, cute::_, k_block), accum_gate); + }); + } + + // load tail + cute::for_each(cute::make_int_sequence{}, + [&](auto WaitIndex) + { + k_tile_count--; + using WaitIndex_t = decltype(WaitIndex); + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1g_p = tOsfc1g(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + } + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), + tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); + if (k_block == 0) + { + // only update smem_pipe_read + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == KT::Stages) ? 
0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), + tOrfc1(cute::_, cute::_, k_block), accum); + cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), + tOrfc1g(cute::_, cute::_, k_block), accum_gate); + }); + }); + // mma tail + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1g_p(cute::_, cute::_, k_block_next), + tOrfc1g_copy_view(cute::_, cute::_, k_block_next)); + // Thread-level register gemm for k_block + cute::gemm( + tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum); + cute::gemm(tiled_mma, accum_gate, tOrInput(cute::_, cute::_, k_block), + tOrfc1g(cute::_, cute::_, k_block), accum_gate); + }); + // if (cute::thread0()) { + // cute::print(accum_gate(0, 0, 0)); + // printf("\n"); + // } + // (2) add bias if it has.. + if (params.ptr_bias != nullptr) + { + cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); + cute::Tensor gBias_gate = gBias_gate_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); + cute::Tensor tOgBias = thr_mma.partition_C(gBias); + cute::Tensor tOgBiasg = thr_mma.partition_C(gBias_gate); + for (int i = 0; i < cute::size(accum); i++) + { + accum(i) += tOgBias(i); + accum_gate(i) += tOgBiasg(i); + } + } + + // (3) calculate swiglu + using ActivationFn = typename KT::ActivationFn; + ActivationFn fn{}; + CUTLASS_PRAGMA_UNROLL + for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++) + { + accum(temp_iter) = fn(accum_gate(temp_iter)) * accum(temp_iter); + } + + // (4) push all the result to smem + // (4.1) convert result from ElementAccum to ElementInput + cute::Tensor temp_accum = util_convert_type(accum); + // if (cute::thread0()) { + // cute::print(temp_accum(0, 0, 0)); + // printf("\n"); + // } + // (4.2) retile rf and smem for copy back.. + auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + // cute::clear(sO); + cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum); + cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO); + + // (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..) + cute::copy(smem_tiled_copy_O, taccumrO, taccumsO); + __syncthreads(); + + // (4.4) sO -> rO -> gO + + typename KT::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + // auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); // + // remember, for all the threads in the same col, they have the same idx for bias.. + cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx); + // cute::Tensor gBias = gBias_mn(cute::_, cute::_, 0, block_n_idx); // bias only have one row.. 
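// The gated epilogue in step (3) above combines the two accumulators element-wise:
// accum = ActivationFn(accum_gate) * accum, which is SwiGLU when ActivationFn is SiLU
// (GeGLU is the same pattern with GELU). A scalar sketch of that update, assuming SiLU:
__device__ inline float swiglu_scalar(float x, float x_gate)
{
    float const silu_gate = x_gate / (1.0f + __expf(-x_gate)); // SiLU(x_gate)
    return silu_gate * x;                                      // fn(accum_gate) * accum
}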
+ auto tOsO = gmem_thr_copy_O.partition_S(sO); + auto tOgO = gmem_thr_copy_O.partition_D(gO); + // auto tOgBias = gmem_thr_copy_O.partition_D(gBias); + cute::Tensor cOutput = cute::make_identity_tensor( + cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{}))); + cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput); + cute::Tensor tOrO = cute::make_tensor(cute::shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<1>(tOgO); ++m) + { + if (cute::get<0>(tOcO(0, m, 0)) < residue_m) + { + cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_)); + } + } + } +}; + +template +struct Fused_Moe_Kernel_routine_sm80> +{ + + using KT = Fused_Moe_Kernel_traits_sm80; + using Params = Routine_Params; + + CUTE_DEVICE auto gmem_tensor_init(int const problem_index, int const gemm_m, Params const& params) + { + using X = cute::Underscore; + + int const M = gemm_m; + int const N1 = params.gemm_n; + int const K1 = params.gemm_k; + bool const bias_is_broadcast = params.bias_is_broadcast; + + int const row_jump = ((problem_index == 0) ? 0 : params.total_tokens_including_expert[problem_index - 1]); + typename KT::ElementInput const* ptr_input_ = params.ptr_input + row_jump * K1; + typename KT::ElementWeight const* ptr_fc1_ = params.ptr_fc1 + problem_index * N1 * K1; + typename KT::ElementInput const* ptr_bias_ = (params.ptr_bias == nullptr) + ? nullptr + : (bias_is_broadcast ? params.ptr_bias + problem_index * N1 : params.ptr_bias + row_jump * N1); + typename KT::ElementOutput* ptr_output_ = params.ptr_output + row_jump * N1; + + cute::Tensor mInput_mk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_input_)), + cute::make_shape(M, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mfc1_nk + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_fc1_)), + cute::make_shape(N1, K1), cute::make_stride(K1, cute::_1{})); + + cute::Tensor mBias_mn = cute::make_tensor( + cute::make_gmem_ptr(static_cast(ptr_bias_)), cute::make_shape(M, N1), + cute::make_stride(bias_is_broadcast ? cute::Int<0>{} : N1, + cute::_1{})); // trick: bias shape is [1, N], but we use [M, N]. + + cute::Tensor mOutput_mn + = cute::make_tensor(cute::make_gmem_ptr(static_cast(ptr_output_)), + cute::make_shape(M, N1), cute::make_stride(N1, cute::_1{})); + + cute::Tensor gInput_mk = cute::local_tile(mInput_mk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_K, m, k) + cute::Tensor gfc1_nk = cute::local_tile(mfc1_nk, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_N, BLK_K, n, k) + + cute::Tensor gBias_mn = cute::local_tile(mBias_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + cute::Tensor gOutput_mn = cute::local_tile(mOutput_mn, typename KT::TileShape{}, + cute::make_coord(cute::_, cute::_, cute::_), cute::Step{}); // (BLK_M, BLK_N, m, n) + + return cute::make_tuple(gInput_mk, gfc1_nk, gBias_mn, gOutput_mn); + } + + // be careful, m_idx will change when use another tile shape.. 
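// gmem_tensor_init above carves one expert's slice out of the packed buffers: token rows of
// all experts are stored back to back, total_tokens_including_expert is an inclusive prefix
// sum of rows per expert, and the weights are one contiguous [gemm_n, gemm_k] block per expert.
// The same offset arithmetic as a host-side sketch (struct and function names are illustrative;
// requires <cstdint> for int64_t):
struct ExpertSlice
{
    int64_t first_row;     // row offset into ptr_input (stride gemm_k) and ptr_output (stride gemm_n)
    int64_t weight_offset; // element offset of this expert's weight block inside ptr_fc1
};

inline ExpertSlice locate_expert(
    int64_t const* total_tokens_including_expert, int expert, int gemm_n, int gemm_k)
{
    int64_t const row_jump = (expert == 0) ? 0 : total_tokens_including_expert[expert - 1];
    return ExpertSlice{row_jump, static_cast<int64_t>(expert) * gemm_n * gemm_k};
}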
+ CUTE_DEVICE void run_routine( + Params const& params, int const problem_index, int const block_m_idx, int const block_n_idx, int const gemm_m) + { + extern __shared__ char smem_[]; + typename KT::SharedStorage& shared_storage = *reinterpret_cast(smem_); + int const thread_idx = threadIdx.x; + bool const bias_is_broadcast = params.bias_is_broadcast; + // gmem tensor partition .. + auto [gInput_mk, gfc1_nk, gBias_mn, gOutput_mn] = gmem_tensor_init(problem_index, gemm_m, params); + int const residue_m = gemm_m - block_m_idx * cute::size<0>(gInput_mk); + auto const n_tile_count = cute::size<2>(gfc1_nk); + + // smem tensor .. + cute::Tensor sInput = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_input.data()), typename KT::SmemLayoutA{}); // (BLK_M, BLK_K, Stage) + cute::Tensor sfc1_weight = cute::make_tensor(cute::make_smem_ptr(shared_storage.smem_fc1_weight.data()), + typename KT::SmemLayoutB{}); // (BLK_N, BLK_K, Stage) + cute::Tensor sO = cute::make_tensor( + cute::make_smem_ptr(shared_storage.smem_o.data()), typename KT::SmemLayoutO{}); // (BLK_M, BLK_N) + + // (1) first step, get the fc1_res and fc1_gate + + // (1.1) get partition for gmem -> smem + cute::Tensor gInput = gInput_mk(cute::_, cute::_, block_m_idx, cute::_); // (BLK_M, BLK_K, k) + cute::Tensor gfc1 = gfc1_nk(cute::_, cute::_, block_n_idx, cute::_); // (BLK_N, BLK_K, k) + + typename KT::GmemTiledCopyA gmem_tiled_copy_A; + typename KT::GmemTiledCopyB gmem_tiled_copy_B; + auto gmem_thr_copy_A = gmem_tiled_copy_A.get_slice(thread_idx); + auto gmem_thr_copy_B = gmem_tiled_copy_B.get_slice(thread_idx); + + cute::Tensor tInputgInput = gmem_thr_copy_A.partition_S(gInput); // (ACPY,ACPY_M,ACPY_K,k) + cute::Tensor tInputsInput = gmem_thr_copy_A.partition_S(sInput); // (ACPY,ACPY_M,ACPY_K,Stage) + cute::Tensor tfc1gfc1 = gmem_thr_copy_B.partition_S(gfc1); // (BCPY,BCPY_N,BCPY_K,k) + cute::Tensor tfc1sfc1 = gmem_thr_copy_B.partition_D(sfc1_weight); // (BCPY,BCPY_N,BCPY_K,Stage) + + // Allocate predicate tensors for input and fc weight (actually we only need input predicate tensor) + cute::Tensor tInputpInput + = cute::make_tensor(cute::make_shape(cute::size<1>(tInputsInput), cute::size<2>(tInputsInput)), + cute::Stride{}); + // Construct identity layout for sInput + cute::Tensor cInput = make_identity_tensor( + make_shape(cute::size<0>(sInput), cute::size<1>(sInput))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + + // Repeat the partitioning with identity layouts + cute::Tensor tInputcInput = gmem_thr_copy_A.partition_S(cInput); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<0>(tInputpInput); ++m) + { + tInputpInput(m, 0) = cute::get<0>(tInputcInput(0, m, 0)) < residue_m; // blk_m coord < residue_m + } + + // (1.2) prefetch gmem -> smem + cute::clear(tInputsInput); // we don't need to clear tfc1sfc1.. 
+        auto k_tile_iter = cute::make_coord_iterator(cute::size<2>(gInput)); // the iterator starts at tile 0
+        int k_tile_count = cute::size<2>(gInput);
+        CUTLASS_PRAGMA_UNROLL
+        for (int k_pipe = 0; k_pipe < KT::Stages - 1; ++k_pipe)
+        {
+            if (k_tile_count <= 0)
+            {
+                cute::clear(tInputpInput);
+            }
+            // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
+            //    tInputsInput(cute::_, cute::_, cute::_, k_pipe));
+            // use copy_if so rows beyond residue_m are not loaded
+            cute::copy_if(gmem_tiled_copy_A, tInputpInput, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter),
+                tInputsInput(cute::_, cute::_, cute::_, k_pipe));
+            cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter),
+                tfc1sfc1(cute::_, cute::_, cute::_, k_pipe));
+            cute::cp_async_fence();
+            k_tile_count--;
+            if (k_tile_count > 0)
+            {
+                ++k_tile_iter;
+            }
+        }
+
+        // (1.3) get partition for rf
+        typename KT::TiledMma tiled_mma;
+        auto thr_mma = tiled_mma.get_thread_slice(thread_idx);
+        cute::Tensor tOrInput = thr_mma.partition_fragment_A(sInput(cute::_, cute::_, 0));    // (MMA,MMA_M,MMA_K)
+        cute::Tensor tOrfc1 = thr_mma.partition_fragment_B(sfc1_weight(cute::_, cute::_, 0)); // (MMA,MMA_N,MMA_K)
+
+        cute::Tensor accum
+            = cute::partition_fragment_C(tiled_mma, cute::take<0, 2>(typename KT::TileShape{})); // (MMA,MMA_M,MMA_N)
+        cute::clear(accum);
+        // check that the fragment shapes are consistent
+        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrInput) == cute::size<1>(accum));  // MMA_M
+        CUTE_STATIC_ASSERT_V(cute::size<1>(tOrfc1) == cute::size<2>(accum));    // MMA_N
+        CUTE_STATIC_ASSERT_V(cute::size<2>(tOrInput) == cute::size<2>(tOrfc1)); // MMA_K
+        CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_A) == cute::size(tiled_mma));
+        CUTE_STATIC_ASSERT_V(cute::size(gmem_tiled_copy_B) == cute::size(tiled_mma));
+
+        // (1.4) retile the smem and rf tensors for copy
+ auto smem_tiled_copy_A = cute::make_tiled_copy_A(typename KT::SmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + cute::Tensor tOsInput = smem_thr_copy_A.partition_S(sInput); // (CPY,CPY_M,CPY_K,Stage) + cute::Tensor tOrInput_copy_view = smem_thr_copy_A.retile_D(tOrInput); // (CPY,CPY_M,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsInput) == cute::size<1>(tOrInput_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsInput) == cute::size<2>(tOrInput_copy_view)); // CPY_K + + auto smem_tiled_copy_B = cute::make_tiled_copy_B(typename KT::SmemCopyAtomB{}, tiled_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + cute::Tensor tOsfc1 = smem_thr_copy_B.partition_S(sfc1_weight); // (CPY,CPY_N,CPY_K,Stage) + cute::Tensor tOrfc1_copy_view = smem_thr_copy_B.retile_D(tOrfc1); // (CPY,CPY_N,CPY_K) + CUTE_STATIC_ASSERT_V(cute::size<1>(tOsfc1) == cute::size<1>(tOrfc1_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(cute::size<2>(tOsfc1) == cute::size<2>(tOrfc1_copy_view)); // CPY_K + + // (1.5) mainloop + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = KT::Stages - 1; + + cute::Tensor tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + cute::Tensor tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + + constexpr int K_BLOCK_MAX = cute::size<2>(tOrInput); + // prefetch register pipeline + if constexpr (K_BLOCK_MAX > 1) + { + cute::cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, cute::Int<0>{}), + tOrInput_copy_view(cute::_, cute::_, cute::Int<0>{})); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, cute::Int<0>{}), + tOrfc1_copy_view(cute::_, cute::_, cute::Int<0>{})); + } + // k loop for mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) + { + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + } + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + // cute::copy(gmem_tiled_copy_A, tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + // tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy_if(gmem_tiled_copy_A, tInputpInput, + tInputgInput(cute::_, cute::_, cute::_, *k_tile_iter), + tInputsInput(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::copy(gmem_tiled_copy_B, tfc1gfc1(cute::_, cute::_, cute::_, *k_tile_iter), + tfc1sfc1(cute::_, cute::_, cute::_, smem_pipe_write)); + cute::cp_async_fence(); + if (k_tile_count - 1 > 0) + { + ++k_tile_iter; + } + + // Advance the pipe -- Doing it here accounts for K_BLOCK_MAX = 1 (no rmem pipe) + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == KT::Stages) ? 
0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), + accum); + }); + } + // load tail + cute::for_each(cute::make_int_sequence{}, + [&](auto WaitIndex) + { + k_tile_count--; + using WaitIndex_t = decltype(WaitIndex); + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + tOsInput_p = tOsInput(cute::_, cute::_, cute::_, smem_pipe_read); + tOsfc1_p = tOsfc1(cute::_, cute::_, cute::_, smem_pipe_read); + cute::cp_async_wait(); + __syncthreads(); + } + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + if (k_block == 0) + { + // only update smem_pipe_read + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == KT::Stages) ? 0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + cute::gemm(tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), + tOrfc1(cute::_, cute::_, k_block), accum); + }); + }); + // mma tail + cute::for_each(cute::make_int_sequence{}, + [&](auto k_block) + { + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + cute::_1{}) % K_BLOCK_MAX; + cute::copy(smem_tiled_copy_A, tOsInput_p(cute::_, cute::_, k_block_next), + tOrInput_copy_view(cute::_, cute::_, k_block_next)); + cute::copy(smem_tiled_copy_B, tOsfc1_p(cute::_, cute::_, k_block_next), + tOrfc1_copy_view(cute::_, cute::_, k_block_next)); + // Thread-level register gemm for k_block + cute::gemm( + tiled_mma, accum, tOrInput(cute::_, cute::_, k_block), tOrfc1(cute::_, cute::_, k_block), accum); + }); + // if (cute::thread0()) { + // cute::print(accum_gate(0, 0, 0)); + // printf("\n"); + // } + // (2) add bias if it has.. + if (params.ptr_bias != nullptr) + { + cute::Tensor gBias = gBias_mn(cute::_, cute::_, bias_is_broadcast ? 0 : block_m_idx, block_n_idx); + cute::Tensor tOgBias = thr_mma.partition_C(gBias); + for (int i = 0; i < cute::size(accum); i++) + { + accum(i) += tOgBias(i); + } + } + // (3) calculate swiglu + using ActivationFn = typename KT::ActivationFn; + ActivationFn fn{}; + CUTLASS_PRAGMA_UNROLL + for (int temp_iter = 0; temp_iter < cute::size(accum); temp_iter++) + { + accum(temp_iter) = fn(accum(temp_iter)); + } + + // (4) push all the result to smem + // (4.1) convert result from ElementAccum to ElementInput + cute::Tensor temp_accum = util_convert_type(accum); + // if (cute::thread0()) { + // cute::print(temp_accum(0, 0, 0)); + // printf("\n"); + // } + // (4.2) retile rf and smem for copy back.. + auto smem_tiled_copy_O = cute::make_tiled_copy_C(typename KT::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + // cute::clear(sO); + cute::Tensor taccumrO = smem_thr_copy_O.retile_S(temp_accum); + cute::Tensor taccumsO = smem_thr_copy_O.partition_D(sO); + + // (4.3) copy rf result to smem (TODO: maybe use forloop for better performance..) 
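// Why the result takes a detour through sO in steps (4.2)-(4.4): the MMA accumulators live in
// a per-thread register layout that is not contiguous along N, so the fragment is converted
// (float -> ElementInput) and staged in shared memory, then re-read with GmemTiledCopyO so each
// thread can issue wide, coalesced stores to gmem. Only the element conversion is sketched here,
// assuming ElementInput is half (requires <cuda_fp16.h>):
__device__ inline __half accum_to_output(float acc)
{
    return __float2half_rn(acc); // util_convert_type applies this element-wise over the fragment
}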
+ cute::copy(smem_tiled_copy_O, taccumrO, taccumsO); + __syncthreads(); + + // (4.4) sO -> rO -> gO + + typename KT::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + // auto gmem_thr_copy_Bias = gmem_tiled_copy_O.get_thread_slice(thread_idx % KT::kGmemTrheadsPerRow); // + cute::Tensor gO = gOutput_mn(cute::_, cute::_, block_m_idx, block_n_idx); + auto tOsO = gmem_thr_copy_O.partition_S(sO); + auto tOgO = gmem_thr_copy_O.partition_D(gO); + cute::Tensor cOutput = cute::make_identity_tensor( + cute::make_shape(cute::size<0>(typename KT::TileShape{}), cute::size<1>(typename KT::TileShape{}))); + cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cOutput); + cute::Tensor tOrO = cute::make_tensor(cute::shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<1>(tOgO); ++m) + { + if (cute::get<0>(tOcO(0, m, 0)) < residue_m) + { + cute::copy(gmem_tiled_copy_O, tOrO(cute::_, m, cute::_), tOgO(cute::_, m, cute::_)); + } + } + } +}; + +} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh new file mode 100644 index 00000000000..b4c90085dbb --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/fused_moe_kernel_traits.cuh @@ -0,0 +1,215 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include + +namespace fused_moe +{ +template +struct Routine_Arguments +{ + ElementInput* ptr_input{}; + ElementWeight* ptr_fc1{}; + ElementInput* ptr_bias{}; + ElementOutput* ptr_output{}; + int64_t const* total_tokens_including_expert{}; + int gemm_n{}; + int gemm_k{}; + int num_expert{}; + bool bias_is_broadcast{}; +}; + +template +struct Routine_Params +{ + ElementInput* ptr_input{}; + ElementWeight* ptr_fc1{}; + ElementInput* ptr_bias{}; + ElementOutput* ptr_output{}; + int64_t const* total_tokens_including_expert{}; + int gemm_n{}; + int gemm_k{}; + int num_expert{}; + bool bias_is_broadcast{}; +}; + +enum class Activation_Type +{ + Gelu = 0, + Relu, + Silu, + Swiglu, + Geglu, + Identity, + InvalidType +}; + +constexpr bool isGateActivation(Activation_Type const& activation_type) +{ + return activation_type == Activation_Type::Swiglu || activation_type == Activation_Type::Geglu; +} + +template +constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) +{ + return Activation_Type::InvalidType; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) +{ + return Activation_Type::Identity; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool /*is_gate*/) +{ + return Activation_Type::Relu; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool is_gate) +{ + return is_gate ? Activation_Type::Swiglu : Activation_Type::Silu; +} + +template <> +constexpr Activation_Type EpilogueRouting(bool is_gate) +{ + return is_gate ? Activation_Type::Geglu : Activation_Type::Gelu; +} + +/* fusing all three kernels has many limitations. This is the simpler version. Just fuse first two kernels..*/ +template +struct Fused_Moe_Kernel_traits_sm80 +{ + using ElementInput = ElementInput_; + using ElementWeight = ElementWeight_; + using ElementAccum = float; + using ElementOutput = ElementOutput_; + + using index_t = uint32_t; + static_assert(TileM_ % 16 == 0); + static_assert(TileN_ % 32 == 0); + static_assert(TileK_ % 32 == 0); + static constexpr int Stages = Stages_; + static constexpr int kTileM = TileM_; + static constexpr int kTileN = TileN_; + static constexpr int kTileK = (kTileM > 16) ? (TileK_) : (TileK_ >= 64 ? TileK_ : 64); + + // tile shape + using TileShape = cute::Shape, cute::Int, cute::Int>; + static constexpr int kWarpsCount = 4; + static constexpr int kThreadCount = kWarpsCount * 32; + + // MMA atom arch and layout + using MMA_Atom_Arch = std::conditional_t, + cute::MMA_Atom, cute::MMA_Atom>; + // using ValLayoutMNK = cute::Layout>; + using ThreadLayoutMNK + = std::conditional_t, cute::_1>>, + cute::Layout, cute::_1>>>; + using ValLayoutMNK = std::conditional_t, + cute::Tile>; + using TiledMma = cute::TiledMMA; // 32x32x16 or 16x64x16 MMA for LDSM if kWarp = 4 + static constexpr int kAlignment = 8; + static constexpr int kBlcokKSmem = (kTileM == 16) ? 
64 : 32; + // A memory copy operand + using DefaultOperandA + = DefaultGemm_TensorOpSm80_OperandA; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B memory copy operand + using DefaultOperandB + = DefaultGemm_TensorOpSm80_OperandB; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Output memory copy operand + using SmemLayoutAtomO = SmemLayoutAtomA; + using SmemCopyAtomO = cute::Copy_Atom; + static constexpr int kGmemElementPerLoad = sizeof(cute::uint128_t) / sizeof(ElementOutput); + static constexpr int kGmemTrheadsPerRow = kBlcokKSmem / kGmemElementPerLoad; + using GmemLayoutAtomO + = cute::Layout, cute::Int>, + cute::Stride, cute::_1>>; + using GmemTiledCopyO = decltype(cute::make_tiled_copy(cute::Copy_Atom{}, + GmemLayoutAtomO{}, cute::Layout>{})); + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2); + static_assert(cute::size<0>(TileShape{}) % cute::size<0>(SmemLayoutAtomA{}) == 0); // M + static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomA{}) == 0); // K + static_assert(cute::rank(SmemLayoutAtomB{}) == 2); + static_assert(cute::size<1>(TileShape{}) % cute::size<0>(SmemLayoutAtomB{}) == 0); // N + static_assert(cute::size<2>(TileShape{}) % cute::size<1>(SmemLayoutAtomB{}) == 0); // K + + using SmemLayoutA = decltype(cute::tile_to_shape(SmemLayoutAtomA{}, + cute::make_shape( + cute::shape<0>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int{}))); // BLK_M, BLK_K, Stages + using SmemLayoutB = decltype(cute::tile_to_shape(SmemLayoutAtomB{}, + cute::make_shape( + cute::shape<1>(TileShape{}), cute::shape<2>(TileShape{}), cute::Int{}))); // BLK_N, BLK_K, Stages + using SmemLayoutO = decltype(cute::tile_to_shape( + SmemLayoutAtomO{}, cute::make_shape(cute::shape<0>(TileShape{}), cute::shape<1>(TileShape{})))); // BLK_M, BLK_N + + // we need at least 2 stages.. + static_assert(Stages >= 2); + + struct SharedStorageNormal : cute::aligned_struct<128> + { + cute::array_aligned> smem_input; + cute::array_aligned> smem_fc1_weight; + cute::array_aligned> smem_o; + }; + + struct SharedStorageGate : cute::aligned_struct<128> + { + cute::array_aligned> smem_input; + cute::array_aligned> smem_fc1_gate_weight; + cute::array_aligned> smem_fc1_weight; + cute::array_aligned> smem_o; + }; + + using SharedStorage = std::conditional_t; + + using ActivationFn = std::conditional_t, + std::conditional_t, + std::conditional_t, cutlass::epilogue::thread::Identity>>>; + + static constexpr int kSmemSize = static_cast(sizeof(SharedStorage)); + + static constexpr bool can_implement(int const avaliable_smem_size) + { + return avaliable_smem_size > kSmemSize; + } + + // #endif +}; +} // namespace fused_moe diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h new file mode 100644 index 00000000000..80a4d856085 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h @@ -0,0 +1,73 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Scheduler for grouped GEMM +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/matrix_coord.h" + +#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" +#include "cutlass_extensions/gemm/kernel/moe_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct GemmMoeProblemVisitor + : public MoeProblemVisitor, ThreadblockShape, + GroupScheduleMode_, PrefetchTileCount, ThreadCount> +{ + + static bool const kTransposed = Transposed; + + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using Base + = MoeProblemVisitor; + using Params = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + // + // Methods + // + CUTLASS_DEVICE + GemmMoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, shared_storage_, block_idx) + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp new file mode 100644 index 00000000000..3a084ee04fb --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_universal_gated.hpp @@ -0,0 +1,70 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel +{ + +//////////////////////////////////////////////////////////////////////////////// + +/* + * Stateless universal device GEMM kernel type that treats GEMM as + * a composition of a collective mainloop and a collective epilogue. + * + * Supports both the 2.x and 3.x APIs based on whether the first type is + * a cute::tuple<> or not. + * 2.x API implementation: cutlass/gemm/kernel/gemm_universal.h + * 3.x API implementation: cutlass/gemm/kernel/gemm_*.hpp + * + * In the following declaration, the name preceding the 'Or' refers to + * 3.x API type argument order, and the name succeeding the 'Or' refers to + * 2.x API type argument order. Template arguments without two names + * belong to the 3.x API only. + **/ +template +class GemmUniversalGated; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel + +//////////////////////////////////////////////////////////////////////////////// + +#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp" +#include "cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp" +//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h new file mode 100644 index 00000000000..0650ca8ded4 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h @@ -0,0 +1,585 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GEMM kernel to support the epilogue visitor model + for customized softmax partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. + + original file: 3rdparty/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/trace.h" + +#include "cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h" + +namespace tk = tensorrt_llm::common; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithEpilogueVisitor +{ +public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementCompute = typename EpilogueVisitor::ElementCompute; + using LayoutAlphaCol = cutlass::layout::RowMajor; + using LayoutAlphaRow = cutlass::layout::ColumnMajor; + using TensorRefAlphaCol = TensorRef; + using TensorRefAlphaRow = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename 
Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + using EpilogueOutputOp = + typename Epilogue::Visitor::ElementwiseFunctor; // Define type so GemmUniversalBase doesn't complain + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment + = const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value); + + // + // Structures + // + + /// Argument structure + struct Arguments + { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + TensorRefA ref_A; + TensorRefB ref_B; + tk::QuantMode quant_option; + TensorRefAlphaCol ref_alpha_col; + TensorRefAlphaRow ref_alpha_row; + TensorRefC ref_C; + TensorRefC ref_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + int64_t batch_stride_D; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments() + : mode(GemmUniversalMode::kGemm) + , batch_count(1) + { + } + + /// constructs an arguments structure + Arguments(GemmUniversalMode mode_, GemmCoord problem_size_, int batch_count_, TensorRefA ref_A_, + TensorRefB ref_B_, tk::QuantMode quant_option_, TensorRefAlphaCol ref_alpha_col_, + TensorRefAlphaRow ref_alpha_row_, TensorRefC ref_C_, TensorRefC ref_D_, int64_t batch_stride_A_, + int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_) + : mode(mode_) + , problem_size(problem_size_) + , batch_count(batch_count_) + , ref_A(ref_A_) + , ref_B(ref_B_) + , quant_option(quant_option_) + , ref_alpha_col(ref_alpha_col_) + , ref_alpha_row(ref_alpha_row_) + , ref_C(ref_C_) + , ref_D(ref_D_) + , batch_stride_A(batch_stride_A_) + , batch_stride_B(batch_stride_B_) + , batch_stride_D(0) + , epilogue_visitor(epilogue_visitor_) + { + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params + { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void* ptr_A; + void* ptr_B; + tk::QuantMode quant_option; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_col; + typename EpilogueVisitor::ScaleTileIterator::Element* ptr_alpha_row; + ElementC* ptr_C; + ElementC* ptr_D; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : swizzle_log_tile(0) + , params_A(0) + , params_B(0) + , params_alpha_col(0) + , params_C(0) + , params_D(0) + , batch_count(0) + , gemm_k_size(0) + , 
mode(cutlass::gemm::GemmUniversalMode::kGemm) + , ptr_A(nullptr) + , ptr_B(nullptr) + , ptr_alpha_col(nullptr) + , ptr_alpha_row(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , batch_stride_A(0) + , batch_stride_B(0) + { + } + + Params( + Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape_, int gemm_k_size_, int* workspace_) + : problem_size(args.problem_size) + , swizzle_log_tile(0) + , params_A(args.ref_A.layout()) + , params_B(args.ref_B.layout()) + , params_alpha_col(args.ref_alpha_col.layout()) + , params_alpha_row(args.ref_alpha_col.layout()) + , params_C(args.ref_C.layout()) + , params_D(args.ref_D.layout()) + , mode(args.mode) + , batch_count(args.batch_count) + , gemm_k_size(args.problem_size.k()) + , ptr_A(args.ref_A.data()) + , ptr_B(args.ref_B.data()) + , quant_option(args.quant_option) + , ptr_alpha_col(args.ref_alpha_col.data()) + , ptr_alpha_row(args.ref_alpha_row.data()) + , ptr_C(args.ref_C.data()) + , ptr_D(args.ref_D.data()) + , batch_stride_A(args.batch_stride_A) + , batch_stride_B(args.batch_stride_B) + , epilogue_visitor(args.epilogue_visitor) + { + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + int const kAlignK + = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) + { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage + { + + typename Mma::SharedStorage main_loop; + + struct + { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + +public: + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { + + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) + { + isAMisaligned = problem_size.k() % kAlignmentA; + } + else if (platform::is_same::value) + { + isAMisaligned = problem_size.m() % kAlignmentA; + } + else if (platform::is_same>::value + || platform::is_same>::value) + { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) + { + isBMisaligned = problem_size.n() % kAlignmentB; + } + else if (platform::is_same::value) + { + isBMisaligned = problem_size.k() % kAlignmentB; + } + else if (platform::is_same>::value + || platform::is_same>::value) + { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) + { + isCMisaligned = problem_size.n() % kAlignmentC; + } + else if (platform::is_same::value) + { + isCMisaligned = problem_size.m() % kAlignmentC; + } + else if (platform::is_same>::value + || platform::is_same>::value) + { + 
isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) + { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) + { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) + { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) + { + return can_implement(args.problem_size); + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + + return 0; + } + +#define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) + { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() + || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) + { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA* ptr_A = static_cast(params.ptr_A); + ElementB* ptr_B = static_cast(params.ptr_B); + +#if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) + { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) + { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) + { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) + { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } +#endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
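// Split-K bookkeeping used above: for kGemm / kGemmSplitKParallel each threadblock along
// grid.k() works on a contiguous K-slice of length gemm_k_size, and the last slice is clamped
// to the true K extent. The same arithmetic as a small host-side helper (names are illustrative):
inline void split_k_slice(
    int problem_k, int gemm_k_size, int slice_idx, int num_slices, int& offset_k, int& problem_size_k)
{
    offset_k = slice_idx * gemm_k_size;           // where this slice starts along K
    problem_size_k = (slice_idx + 1 < num_slices) // every slice but the last ends on a
        ? (slice_idx + 1) * gemm_k_size           // gemm_k_size boundary ...
        : problem_k;                              // ... the last one ends at K itself
}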
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor(params.epilogue_visitor, shared_storage.epilogue.visitor, + params.problem_size.mn(), thread_idx, warp_idx, lane_idx, params.params_alpha_col, params.params_C, + params.params_D, params.quant_option, params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C, + params.ptr_D, threadblock_offset, blockIdx.y * params.problem_size.m()); + + if (params.mode == GemmUniversalMode::kGemm) + { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) + { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) + { + if constexpr (platform::is_same::value) + { + run_kernel_(params, shared_storage); + } + else + { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 900) + // TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels. + run_kernel(params, shared_storage); +#else + static_assert( + false, "Invalid architecture being compiled. 
Only Volta+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h new file mode 100644 index 00000000000..6dc6ffc1a9f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h @@ -0,0 +1,143 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/* + This file exists so that we use the same weight layout for MoE grouped gemm and regular gemm when the weight is + quantized. The preprocessing code reads this template to know how to organize the quantized weight matrices + to be consumed by CUTLASS. + + Note that for int4, ThreadBlockK MUST be 64. + + */ + +#pragma once + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/platform/platform.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +template +struct LayoutDetailsB +{ +}; + +// Specializations for Turing+ when B is FP16. These are currently only used for MoE networks. +// TODO - Switch this to column major for weights since gemms should be more performant. 
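+// The specializations below size ThreadblockK so that one K-slice of B covers a 128-byte cache
+// line of the activation type and, for quantized B, interleave columns so that a full cache line
+// of weights is consumed per K-slice. Assuming the usual TensorRT-LLM parameterization
+// (ThreadblockK derived from the activation type, ElementsPerCacheLine from the weight type):
+// fp16 activations give ThreadblockK = 128 * 8 / 16 = 64; int8 weights then give
+// ElementsPerCacheLine = 128 and ColumnsInterleaved = 128 / 64 = 2, while int4 weights give
+// ElementsPerCacheLine = 256 and ColumnsInterleaved = 256 / 64 = 4.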
+template +struct LayoutDetailsB= 75>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB= 75>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB +{ + static constexpr int ThreadblockK = 64; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; + // for fast accumulation + // using Operator = cutlass::arch::OpMultiplyAddFastAccum; +}; + +// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA, +// which signals that we want to dequantize after loading from smem. +template + struct LayoutDetailsB < TypeA, + uint8_t, Arch, + typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template + struct LayoutDetailsB < TypeA, + uint4b_t, Arch, + typename platform::enable_if= 75 && Arch::kMinComputeCapability<90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + +private: + static constexpr int ElementsPerCacheLine = 128 * 8 / sizeof_bits::value; + static constexpr int ColumnsInterleaved = ElementsPerCacheLine / ThreadblockK; + +public: + using Layout = layout::ColumnMajorTileInterleave; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA; +}; + +template +struct LayoutDetailsB= 90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +template +struct LayoutDetailsB= 90>::type> +{ + static constexpr int ThreadblockK = 128 * 8 / cutlass::sizeof_bits::value; + using Layout = layout::ColumnMajor; + static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits::value; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh new file mode 100644 index 00000000000..aac2cb35799 --- /dev/null +++ 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cute_util.cuh @@ -0,0 +1,185 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include + +template +struct DefaultGemm_TensorOpSm80_OperandA; + +template +struct DefaultGemm_TensorOpSm80_OperandB; + +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::half_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::bfloat16_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +/// Operand A - Column-major (M-major) +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::half_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<3, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::bfloat16_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands + +// Operand B - Column-Major (K-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +// Operand B - Row-Major (N-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{ +}; + +// +// F16: 128-by-128-by-32 (small k-block) +// + +/// Operand A - Row-major (K-Major) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<2, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + 
cute::Copy_Atom, cute::half_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype(cute::composition( + cute::Swizzle<2, 3, 3>{}, cute::Layout, cute::Stride>{})); + using SmemCopyAtom = cute::Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype(cute::make_tiled_copy( + cute::Copy_Atom, cute::bfloat16_t>{}, + cute::Layout, cute::Stride>{}, + cute::Layout>{})); +}; + +template +CUTE_DEVICE auto util_convert_type(cute::Tensor const& tensor) +{ + using From_type = typename Engine::value_type; + constexpr int numel = decltype(cute::size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast const*>(tensor.data())); + return cute::make_tensor(cute::make_rmem_ptr(&frag), tensor.layout()); +} + +template +CUTE_DEVICE void util_copy( + TiledCopy const& tiled_copy, cute::Tensor const& S, cute::Tensor& D) +{ + CUTE_STATIC_ASSERT_V(cute::rank(S) == cute::Int<3>{}); + CUTE_STATIC_ASSERT_V(cute::rank(D) == cute::Int<3>{}); + CUTE_STATIC_ASSERT_V(cute::size<0>(S) == cute::size<0>(D)); + CUTE_STATIC_ASSERT_V(cute::size<1>(S) == cute::size<1>(D)); + CUTE_STATIC_ASSERT_V(cute::size<2>(S) == cute::size<2>(D)); + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < cute::size<1>(S); ++m) + { + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < cute::size<2>(S); ++k) + { + cute::copy(tiled_copy, S(cute::_, m, k), D(cute::_, m, k)); + } + } +} diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h new file mode 100644 index 00000000000..b708f7c28b5 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h @@ -0,0 +1,553 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
\file + \brief +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +#include "cutlass_extensions/gemm/kernel/gemm_moe_problem_visitor.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_sm90_traits.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +// This section exists to that we can use the same kernel code for regular gemm and dequantizing gemms. +// It will dispatch to the dequantizing gemm if the Mma type has an Iterator for scales in global. +template +using void_t = void; + +template +struct use_dq_gemm : platform::false_type +{ +}; + +template +struct use_dq_gemm> : platform::true_type +{ +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeFCGemm +{ +public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = false; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + static_assert(!kTransposed, "Transpose problem not supported"); + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + using ElementScale = ElementC; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. 
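+    // These aliases re-export the collective mainloop's configuration (operator class, tile and
+    // warp shapes, arch tag, stage count) so that host-side dispatch code can query the kernel
+    // without reaching into Mma internals.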
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor + = GemmMoeProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments + { + + // + // Data members + // + + int problem_count; + int threadblock_count; + int group_size; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + bool C_is_broadcast; + + int64_t const* total_tokens_including_expert; + int64_t gemm_n; + int64_t gemm_k; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0) + , threadblock_count(0) + , ptr_A(nullptr) + , ptr_B(nullptr) + , weight_scales(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , total_tokens_including_expert(nullptr) + , gemm_n(0) + , gemm_k(0) + , host_problem_sizes(nullptr) + , C_is_broadcast{true} + { + } + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op, + ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C, + bool C_is_broadcast, ElementC* ptr_D, int64_t const* total_tokens_including_expert, int64_t gemm_n, + int64_t gemm_k, GemmCoord* host_problem_sizes = nullptr) + : problem_count(problem_count) + , threadblock_count(threadblock_count) + , group_size(group_size) + , output_op(output_op) + , ptr_A(const_cast(ptr_A)) + , ptr_B(const_cast(ptr_B)) + , weight_scales(const_cast(weight_scales)) + , ptr_C(const_cast(ptr_C)) + , C_is_broadcast{C_is_broadcast} + , ptr_D(ptr_D) + , total_tokens_including_expert(total_tokens_including_expert) + , gemm_n(gemm_n) + , gemm_k(gemm_k) + , host_problem_sizes(nullptr) + { + if (platform::is_same::value || platform::is_same::value) + { + assert(weight_scales); + } + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params + { + + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + int group_size; + bool C_is_broadcast; + + typename EpilogueOutputOp::Params output_op; + + ElementA* ptr_A; + ElementB* ptr_B; + ElementScale* weight_scales; + ElementC* ptr_C; + ElementC* ptr_D; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params() + : ptr_A(nullptr) + , ptr_B(nullptr) + , weight_scales(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , C_is_broadcast(true) + { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor( + args.total_tokens_including_expert, args.gemm_n, args.gemm_k, args.problem_count, workspace, tile_count) + , threadblock_count(args.threadblock_count) + , group_size(args.group_size) + , 
output_op(args.output_op) + , ptr_A(args.ptr_A) + , ptr_B(args.ptr_B) + , weight_scales(args.weight_scales) + , ptr_C(args.ptr_C) + , ptr_D(args.ptr_D) + , C_is_broadcast(args.C_is_broadcast) + { + } + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + { + + problem_visitor = typename ProblemVisitor::Params(args.total_tokens_including_expert, args.gemm_n, + args.gemm_k, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + weight_scales = args.weight_scales; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + C_is_broadcast = args.C_is_broadcast; + } + }; + + /// Shared memory storage structure + union SharedStorage + { + typename ProblemVisitor::SharedStorage problem_visitor; + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + +public: + // + // Methods + // + + CUTLASS_DEVICE + MoeFCGemm() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) + { + if (platform::is_same::value || platform::is_same::value) + { + if (args.weight_scales == nullptr) + { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - weight scales are required for uint8_t and uint4b_t"); + return Status::kInvalid; + } + } + else if (args.weight_scales != nullptr) + { + CUTLASS_TRACE_HOST( + "MoeFCGemm::can_implement() - weight scales are ignored for all types except uint8_t and uint4b_t"); + return Status::kInvalid; + } + else if (args.group_size != args.gemm_k) + { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - scale shape should be (1, gemm_n)"); + return Status::kInvalid; + } + // Handle the case the input is too short + else if (args.gemm_n < Mma::IteratorB::AccessType::kElements) + { + CUTLASS_TRACE_HOST("MoeFCGemm::can_implement() - gemm_n is smaller than the input alignment"); + return Status::kInvalid; + } + return Status::kSuccess; + } + + static size_t get_extra_workspace_size(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape) + { + + return 0; + } + + CUTLASS_DEVICE + void run_kernel_(Params const& params, SharedStorage& shared_storage) + { + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK; + static_assert(platform::is_same::value && kInterleave == 1 + || platform::is_same::value && kInterleave >= 1, + "B must be row major/col major OR col major interleaved."); + + // + // Problem visitor. 
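+        // Each threadblock starts at tile index blockIdx.x and strides by gridDim.x in the
+        // persistent loop below; the visitor maps that flat tile index to an
+        // (expert, tile-within-expert) pair using the cumulative row counts per expert.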
+ // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + const int64_t gemm_k = params.problem_visitor.gemm_k; + const int64_t gemm_n = params.problem_visitor.gemm_n; + int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits::value; + + // Outer 'persistent' loop to iterate over tiles + int loop = 0; + while (problem_visitor.next_tile()) + { + loop++; + + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t cta_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + cutlass::gemm::GemmCoord threadblock_offset( + int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0); + + // Load element pointers. Exchange pointers and strides if working on the transpose + const int64_t rows_to_jump + = problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1]; + ElementA* ptr_A = reinterpret_cast(params.ptr_A) + rows_to_jump * gemm_k; + typename LayoutA::LongIndex ldm_A = gemm_k; + + char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix; + ElementB* ptr_B = reinterpret_cast(byte_ptr_B); + typename LayoutB::LongIndex ldm_B + = platform::is_same::value ? gemm_n : gemm_k * kInterleave; + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + 0, + }; + + cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave}; + + cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()}; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B, + {problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + auto CreateMMA = [&]() + { + if constexpr (use_dq_gemm::value) + return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx); + else + return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + }; + Mma mma = CreateMMA(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Wait for all threads to finish their epilogue phases from the previous tile. 
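+            // SharedStorage is a union, so the mainloop's staging buffers alias the epilogue's;
+            // without this barrier a fast warp could begin overwriting shared memory that other
+            // warps are still reading for the previous tile's epilogue.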
+ __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n(); + + if constexpr (use_dq_gemm::value) + { + const MatrixCoord scale_extent = {1, problem_size.n()}; + typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()), + weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale); + + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators); + } + else + { + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + } + + // + // Epilogue + // + + ElementC* ptr_C = reinterpret_cast(params.ptr_C) + + (params.C_is_broadcast ? problem_idx : rows_to_jump) * gemm_n; + ElementC* ptr_D = reinterpret_cast(params.ptr_D) + rows_to_jump * gemm_n; + + // lora need to set as layout_C(gemm_n) + LayoutC layout_C = params.C_is_broadcast ? LayoutC(0) : LayoutC(gemm_n); + LayoutC layout_D(gemm_n); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn()); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn()); + + Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + if constexpr (platform::is_same>::value) + { + EpilogueOutputOp output_op(params.output_op, problem_idx); + epilogue(output_op, iterator_D, accumulators, iterator_C); + } + else + { + EpilogueOutputOp output_op(params.output_op); + epilogue(output_op, iterator_D, accumulators, iterator_C); + } + + // Next tile + problem_visitor.advance(gridDim.x); + } + } + + template + CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage) + { + if constexpr (platform::is_same::value) + { + run_kernel_(params, shared_storage); + } + else + { + CUTLASS_NOT_IMPLEMENTED(); + } + } + + /* + To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond + to the ArchTag of the cutlass kernel operator. + */ + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { +#if defined(__CUDA_ARCH__) +#if (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 890) + run_kernel(params, shared_storage); +#elif (__CUDA_ARCH__ >= 890) && (__CUDA_ARCH__ < 900) + constexpr bool isFp8 = platform::is_same::value + || platform::is_same::value; + if constexpr (isFp8) + { + run_kernel(params, shared_storage); + } + else + { // reuse sm80 kernel for other types, align with dispatchToArch + run_kernel(params, shared_storage); + } +#elif (__CUDA_ARCH__ >= 900) + run_kernel(params, shared_storage); +#else + static_assert( + false, "Invalid architecture being compiled. 
Only Volta+ supported in weight-only quantization kernels."); +#endif +#else + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h new file mode 100644 index 00000000000..796dc2fe78d --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/moe_problem_visitor.h @@ -0,0 +1,344 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! \file + \brief Base scheduler for grouped problems, using MoE +*/ + +#pragma once + +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct BaseMoeProblemVisitor +{ + using ThreadblockShape = ThreadblockShape_; + + struct ProblemInfo + { + static int32_t const kNoPrefetchEntry = -1; + int32_t problem_idx; + int32_t problem_start; + + CUTLASS_DEVICE + ProblemInfo() + : problem_idx(kNoPrefetchEntry) + , problem_start(kNoPrefetchEntry) + { + } + + CUTLASS_DEVICE + ProblemInfo(int32_t problem_idx_, int32_t problem_start_) + : problem_idx(problem_idx_) + , problem_start(problem_start_) + { + } + }; + + struct Params + { + int64_t const* last_row_for_problem; + int64_t gemm_n; + int64_t gemm_k; + int32_t problem_count; + void const* workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params() + : last_row_for_problem(nullptr) + , gemm_n(0) + , gemm_k(0) + , problem_count(0) + , workspace(nullptr) + , tile_count(0) + { + } + + /// Ctor + CUTLASS_HOST_DEVICE + Params(int64_t const* last_row_for_problem, int64_t gemm_n, int64_t gemm_k, int32_t problem_count, + void const* workspace = nullptr, int32_t tile_count = 0) + : last_row_for_problem(last_row_for_problem) + , gemm_n(gemm_n) + , gemm_k(gemm_k) + , problem_count(problem_count) + , workspace(workspace) + , tile_count(tile_count) + { + } + }; + + Params const& params; + int32_t tile_idx; + int32_t problem_tile_start; + int32_t problem_idx; + + // + // Methods + // + CUTLASS_DEVICE + BaseMoeProblemVisitor(Params const& params_, int32_t block_idx) + : params(params_) + , tile_idx(block_idx) + , 
problem_tile_start(0) + , problem_idx(0) + { + } + + /// Get the grid shape + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem) + { + + return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), + ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN), 1); + } + + /// Gets the global tile index + CUTLASS_HOST_DEVICE + int32_t tile_index() const + { + return tile_idx; + } + + /// Gets the index of the problem + CUTLASS_HOST_DEVICE + int32_t problem_index() const + { + return problem_idx; + } + + CUTLASS_HOST_DEVICE + int32_t threadblock_idx() const + { + return tile_idx - problem_tile_start; + } + + CUTLASS_DEVICE + void advance(int32_t grid_size) + { + tile_idx += grid_size; + } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) + { + ProblemSizeHelper::possibly_transpose_problem(problem); + } + + /// Returns the problem size for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size() const + { + return problem_size(problem_idx); + } + + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size(int idx) const + { + const int64_t prev_problem_row = idx == 0 ? 0 : params.last_row_for_problem[idx - 1]; + const int64_t current_problem_row = params.last_row_for_problem[idx]; + const int64_t gemm_m = current_problem_row - prev_problem_row; + GemmCoord problem(GemmCoord::Index(gemm_m), GemmCoord::Index(params.gemm_n), GemmCoord::Index(params.gemm_k)); + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + CUTLASS_HOST_DEVICE + static int32_t tile_count(cutlass::gemm::GemmCoord const& grid) + { + return ProblemSizeHelper::tile_count(grid); + } + + static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count) + { + int32_t total_tiles = 0; + for (int32_t i = 0; i < problem_count; ++i) + { + auto problem = host_problem_sizes_ptr[i]; + possibly_transpose_problem(problem); + auto grid = grid_shape(problem); + total_tiles += tile_count(grid); + } + + return total_tiles; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MoeProblemVisitor; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// ProblemVisitor that performs all scheduling on device +// +template +struct MoeProblemVisitor : public BaseMoeProblemVisitor +{ + using Base = BaseMoeProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static bool const kRequiresPrecomputation = false; + static int const kThreadsPerWarp = 32; + + struct SharedStorage + { + }; + + // Final tile of the problem loaded by this thread. Each thread will hold + // a separate value. + int32_t problem_ending_tile; + + SharedStorage& shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + MoeProblemVisitor(Params const& params_, SharedStorage& shared_storage_, int32_t block_idx) + : Base(params_, block_idx) + , problem_ending_tile(0) + , shared_storage(shared_storage_) + { + this->problem_idx = -1 * kThreadsPerWarp; + this->problem_tile_start = 0; + } + + CUTLASS_DEVICE + bool next_tile() + { + // Check whether the tile to compute is within the range of the current problem. 
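+        // Each lane of the warp caches the (inclusive prefix-summed) ending tile of one problem,
+        // so reading lane (problem_idx % kThreadsPerWarp) recovers the current problem's last
+        // tile without touching memory.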
+ int32_t problem_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, this->problem_idx % kThreadsPerWarp); + if (this->tile_idx < problem_tile_end) + { + return true; + } + + // Check whether the tile to compute is within the current group of problems fetched by the warp. + // The last tile for this group is the final tile of the problem held by the final thread in the warp. + int32_t group_tile_end = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + // Keep the starting problem for this group in `problem_idx`. This is done to reduce + // register pressure. The starting problem for this group is simply the first problem + // in the group most recently fetched by the warp. + int32_t& group_problem_start = this->problem_idx; + group_problem_start = (this->problem_idx / kThreadsPerWarp) * kThreadsPerWarp; + + // Keep the starting tile for this group in `problem_tile_start`. This is done to reduce + // register pressure. + int32_t& group_tile_start = this->problem_tile_start; + + // Each thread in the warp processes a separate problem to advance until + // reaching a problem whose starting tile is less less than tile_idx. + while (group_tile_end <= this->tile_idx) + { + group_problem_start += kThreadsPerWarp; + if (group_problem_start > this->params.problem_count) + { + return false; + } + + // Since `group_tile_start` is a reference to `this->problem_tile_start`, this + // also sets `this->problem_tile_start`. The fact that `this->problem_tile_start` + // is also set here is used later in `next_tile`. + group_tile_start = group_tile_end; + + int lane_idx = threadIdx.x % kThreadsPerWarp; + int32_t lane_problem = group_problem_start + lane_idx; + + // Compute the number of tiles in the problem assigned to each thread. + problem_ending_tile = 0; + if (lane_problem < this->params.problem_count) + { + cutlass::gemm::GemmCoord problem = this->problem_size(lane_problem); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + problem_ending_tile = this->tile_count(grid); + } + + // Compute a warp-wide inclusive prefix sum to compute the ending tile index of + // each thread's problem. + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kThreadsPerWarp; i <<= 1) + { + int32_t val = __shfl_up_sync(0xffffffff, problem_ending_tile, i); + if (lane_idx >= i) + { + problem_ending_tile += val; + } + } + + // The total tile count for this group is now in the final position of the prefix sum + int32_t tiles_in_group = __shfl_sync(0xffffffff, problem_ending_tile, kThreadsPerWarp - 1); + + problem_ending_tile += group_tile_start; + group_tile_end += tiles_in_group; + } + + // The next problem to process is the first one that does not have ending tile position + // that is greater than or equal to tile index. + int32_t problem_idx_in_group = __popc(__ballot_sync(0xffffffff, problem_ending_tile <= this->tile_idx)); + + this->problem_idx = group_problem_start + problem_idx_in_group; + + // The starting tile for this problem is the ending tile of the previous problem. In cases + // where `problem_idx_in_group` is the first problem in the group, we do not need to reset + // `problem_tile_start`, because it is set to the previous group's ending tile in the while + // loop above. 
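+        // Otherwise the current problem's starting tile is the previous problem's ending tile,
+        // read from the lane that owns that previous problem in the prefix sum.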
+ if (problem_idx_in_group > 0) + { + this->problem_tile_start = __shfl_sync(0xffffffff, problem_ending_tile, problem_idx_in_group - 1); + } + + return true; + } + + static size_t get_workspace_size( + cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count) + { + return 0; + } + + static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, + int32_t block_count, void* host_workspace_ptr) + { + } +}; + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp new file mode 100644 index 00000000000..e3d31a2c5b3 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_cooperative.hpp @@ -0,0 +1,646 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/tensor.hpp" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/workspace.h" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel +{ + +/////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalGated + && CollectiveMainloop_::isGated>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using Activation = typename CollectiveMainloop::Activation; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(ArchTag::kMinComputeCapability >= 90); + + using TileSchedulerTag = TileScheduler_; + using TileScheduler = + typename detail::TileSchedulerSelector::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMma{})) / NumThreadsPerWarpGroup; + static constexpr uint32_t MaxThreadsPerBlock + = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>; + + // Kernel level shared memory storage + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + 
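+            // Mainloop and epilogue tensors are kept as separate members (not a union) so that,
+            // with the warp-specialized persistent design, both stages can hold live data at the
+            // same time.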
using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> + { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + void* workspace{nullptr}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static Params to_underlying_arguments(Arguments const& args, void* workspace) + { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + auto problem_shape = args.problem_shape; + // if constexpr (detail::IF_SWAP_AB::value) { + // // swap M/N + // get<0>(problem_shape) = get<1>(args.problem_shape); + // get<1>(problem_shape) = get<0>(args.problem_shape); + // } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) + { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* scheduler_workspace = workspace_ptr; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + // Precompute the sub tiles numbers in epilogue, pass into tile scheduler. Therefore it will be used + // in separate reduction scheme for streamk case, NumEpilogueSubTiles default value is 1, which means + // subtile will not be used, therefore separate reduction will not be enabled. 
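+        // get_workspace_size() and initialize_workspace() below walk the workspace in this same
+        // order (scheduler, then epilogue, each rounded up to MinWorkspaceAlignment), so the
+        // offsets computed here stay consistent with how the buffer was sized and initialized.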
+ constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, + ClusterShape{}, hw_info, args.scheduler, scheduler_workspace, NumEpilogueSubTiles); + + return {args.mode, problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, + scheduler, workspace}; + } + + static bool can_implement(Arguments const& args) + { + bool implementable = (args.mode == GemmUniversalMode::kGemm) + or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) + { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t get_workspace_size(Arguments const& args) + { + size_t workspace_size = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) + { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + constexpr uint32_t NumEpilogueSubTiles = CollectiveEpilogue::get_store_pipe_increment(TileShape{}); + + status = TileScheduler::template initialize_workspace(args.scheduler, + workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups, + NumEpilogueSubTiles); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups, NumEpilogueSubTiles); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + status = CollectiveEpilogue::initialize_workspace( + args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 get_grid_shape(Params const& params) + { + // Given device SM count, set grid size s.t. 
we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) + { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN + ? TileScheduler::RasterOrderOptions::AlongN + : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 get_block_shape() + { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) + { + using namespace cute; + using X = Underscore; + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if !defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + // Preconditions + static_assert(size(TiledMma{}) == 256, "Cooperative kernel must have TiledMMA operating using 256 threads."); + static_assert(size<0>(TileShape{}) >= 128, + "Cooperative kernel requires Tile Size to be greater than or equal to 128 along the M-dimension."); + + static_assert(cute::rank(StrideA{}) == 3, + "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, + "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + /* In the Cooperative kernel, Consumer0 and Consumer1 collaborate on the same tile */ + enum class WarpGroupRole + { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + enum class ProducerWarpRole + { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int mma_thread_idx = thread_idx % size(TiledMma{}); + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) + { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = 
warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = size(TiledMma{}); + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = size(TiledMma{}); + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = []() + { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) + { + cute::cluster_arrive_relaxed(); + return []() { cute::cluster_wait(); }; + } + else + { + __syncthreads(); + return []() {}; // do nothing + } + }(); + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + TileScheduler scheduler{params.scheduler}; + auto work_tile_info = scheduler.get_current_work(); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, 
shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 3, + "Output of load_init must have at least three elements (A, B, Aux)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) + { + cutlass::arch::warpgroup_reg_dealloc(); + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) + { + bool do_load_order_arrive = true; + while (work_tile_info.is_valid()) + { + if (!TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) + { + work_tile_info = fetch_next_work(work_tile_info, scheduler); + continue; + } + + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the + // work. + auto work_k_tile_count + = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter + = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); + + collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, + load_inputs, blk_coord, k_tile_iter, work_k_tile_count, lane_idx, block_rank_in_cluster, + shared_storage.tensors.mainloop); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(work_k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) + { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) + { + while (work_tile_info.is_valid()) + { + if (!TileScheduler::requires_separate_reduction(params.scheduler)) + { + load_order_barrier.wait(); + } + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state 
= collective_epilogue.load(epi_load_pipeline, + epi_load_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, tiled_mma, lane_idx, + shared_storage.tensors.epilogue, work_tile_info.reduction_subtile_idx()); + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + + else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + cutlass::arch::warpgroup_reg_alloc(); + + // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it + bool do_store_tail = false; + float scale_d0 = params.mainloop.scale_d0; + float scale_d1 = params.mainloop.scale_d1; + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + auto work_k_tile_count + = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + + // Allocate the accumulators for the (M,N) blk_shape + // + // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. + auto accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + auto accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + if (TileScheduler::valid_warpgroup_in_work_tile(work_tile_info)) + { + collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, + accumulators1, work_k_tile_count, mma_thread_idx, shared_storage.tensors.mainloop, + params.mainloop); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, work_k_tile_count); + + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(work_k_tile_count); + } + // Index of warp group within consumer warp groups + int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; + + // Perform reduction across splits, if needed + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators0, NumMmaWarpGroups, consumer_warp_group_idx); + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators1, NumMmaWarpGroups, consumer_warp_group_idx); + + Activation elt_op; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators0); i++) + { + accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]); + } + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) + { + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] + = collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, + epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0, + tiled_mma, mma_thread_idx, shared_storage.tensors.epilogue, + work_tile_info.reduction_subtile_idx()); + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; + epi_store_pipe_producer_state = 
epi_store_pipe_producer_state_next; + do_store_tail = true; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + if (do_store_tail) + { + collective_epilogue.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, epi_store_pipe_producer_state); + } + } // Consumer Warp Groups End +#endif + } + +private: + // Kernel helper function to get next work unit + CUTLASS_DEVICE + typename TileScheduler::WorkTileInfo fetch_next_work( + typename TileScheduler::WorkTileInfo& work_tile_info, TileScheduler& scheduler) const + { + // Check whether we should continue on with the current work unit. If this is the case, + // the work unit will have been updated in continue_current_work to reflect the new + // tile to be computed. + if (scheduler.continue_current_work(work_tile_info)) + { + return work_tile_info; + } + + // Get next work tile + scheduler.advance_to_next_work(); + return scheduler.get_current_work(); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp new file mode 100644 index 00000000000..39886f2431d --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/sm90_gemm_gated_tma_warpspecialized_pingpong.hpp @@ -0,0 +1,621 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" + +#include "cute/util/debug.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel +{ + +/////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalGated + && CollectiveMainloop_::isGated>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + using Activation = typename CollectiveMainloop::Activation; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(!cute::is_same_v, + "Ping-pong kernel does not currently support stream-K scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = + typename detail::TileSchedulerSelector::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 2; + static constexpr uint32_t MaxThreadsPerBlock + = CUTE_STATIC_V(size(TiledMma{})) + (NumMmaWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + /// Register requirement for Load and Math WGs + static constexpr uint32_t LoadRegisterRequirement = 40; + static constexpr uint32_t MmaRegisterRequirement = 232; + + // 1 stage ordered sequence between mainloop and epilogue producer load threads + using LoadWarpOrderBarrier = cutlass::OrderedSequenceBarrier<1, 2>; + + // Order Sequence barrier with two 
stages: one for Mainloop and one for Epilogue + static constexpr uint32_t StagesPerMathWarpGroup = 2; + using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier; + + // Kernel level shared memory storage + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> + { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> + { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; + alignas(16) typename LoadWarpOrderBarrier::SharedStorage load_order; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params + { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerParams scheduler{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static Params to_underlying_arguments(Arguments const& args, void* workspace) + { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + (void) workspace; + auto problem_shape = args.problem_shape; + // if constexpr (detail::IF_SWAP_AB::value) { + // // swap M/N + // get<0>(problem_shape) = get<1>(args.problem_shape); + // get<1>(problem_shape) = get<0>(args.problem_shape); + // } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) + { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* scheduler_workspace = workspace_ptr; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; + + return {args.mode, problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, 
args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, + TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace)}; + } + + static bool can_implement(Arguments const& args) + { + bool implementable = (args.mode == GemmUniversalMode::kGemm) + or (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) + { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + return implementable; + } + + static size_t get_workspace_size(Arguments const& args) + { + size_t workspace_size = 0; + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, + cudaStream_t stream = nullptr, CudaHostAdapter* cuda_adapter = nullptr) + { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = TileScheduler::template initialize_workspace(args.scheduler, + workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + status = CollectiveEpilogue::initialize_workspace( + args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) + { + return status; + } + + return status; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 get_grid_shape(Params const& params) + { + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) + { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + args.raster_order = params.scheduler.raster_order_ == TileScheduler::RasterOrder::AlongN + ? TileScheduler::RasterOrderOptions::AlongN + : TileScheduler::RasterOrderOptions::AlongM; + return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 get_block_shape() + { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) + { + using namespace cute; + using X = Underscore; + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. 
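+// This guard is a hard precondition: unless the translation unit is built for the sm90a
+// target (e.g. with `-gencode=arch=compute_90a,code=sm_90a`), the body below is compiled
+// out and the kernel only prints the error message, producing no output.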
+#if !defined(__CUDA_ARCH_FEAT_SM90_ALL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, + "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, + "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, + "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, + "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + enum class WarpGroupRole + { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2 + }; + enum class ProducerWarpRole + { + Mainloop = 0, + Warp1 = 1, + Epilogue = 2, + Warp3 = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) + { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Mainloop) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + mainloop_pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytes; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Epilogue) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, 
epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename LoadWarpOrderBarrier::Params params_load_order_barrier; + params_load_order_barrier.group_id = producer_warp_role == ProducerWarpRole::Mainloop ? 0 : 1; + params_load_order_barrier.group_size = NumThreadsPerWarp; + LoadWarpOrderBarrier load_order_barrier(shared_storage.pipelines.load_order, params_load_order_barrier); + + typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; + // DMA Load WG will not participate in these Ordered Barrier syncs + params_math_wg_order_barrier.group_id = canonical_warp_group_idx() - static_cast(WarpGroupRole::Consumer0); + params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group + MathWarpGroupOrderBarrier math_wg_order_barrier( + shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = [&]() + { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) + { + cute::cluster_arrive_relaxed(); + return []() { cute::cluster_wait(); }; + } + else + { + __syncthreads(); + return []() {}; // do nothing + } + }(); + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 3, + "Output of load_init must have at least three elements (A, B, Aux)"); + + // Extract out partitioned A and B. 
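+ // gA_mkl / gB_nkl are the TMA tensors for A and B described above; gAux_xkl is the third
+ // tensor returned by load_init for this gated mainloop (the "Aux" operand named in the
+ // static_assert), which is presumably loaded alongside B to feed the second accumulator.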
+ Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + Tensor gAux_xkl = get<2>(load_inputs); + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + + TileScheduler scheduler{params.scheduler}; + + if (warp_group_role == WarpGroupRole::Consumer1) + { + // Advance 2nd Math WG to the next work tile for the startup + scheduler.advance_to_next_work(); + // Advance 2nd Math WG pipeline states to the end of 1st Math WG + mainloop_pipe_consumer_state.advance(k_tile_count); + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + } + auto work_tile_info = scheduler.get_current_work(); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) + { + cutlass::arch::warpgroup_reg_dealloc(); + + // Mainloop Producer Warp + if (producer_warp_role == ProducerWarpRole::Mainloop) + { + bool do_load_order_arrive = true; + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); + + collective_mainloop.load(params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, + load_inputs, blk_coord, k_tile_iter, k_tile_count, lane_idx, block_rank_in_cluster, + shared_storage.tensors.mainloop); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(k_tile_count); + + // Signal for the epilogue load warp to begin + if (do_load_order_arrive) + { + load_order_barrier.arrive(); + do_load_order_arrive = false; + } + + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } // Mainloop Producer Warp End + + // Epilogue Producer Warp + else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) + { + load_order_barrier.wait(); + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + epi_load_pipe_producer_state + = collective_epilogue.load(epi_load_pipeline, epi_load_pipe_producer_state, problem_shape_MNKL, + blk_shape, blk_coord, tiled_mma, lane_idx, shared_storage.tensors.epilogue); + + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } // Epilogue Producer Warp End + } // Producer Warp Group End + + else if (warp_group_role == 
WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) + { + cutlass::arch::warpgroup_reg_alloc(); + + float scale_d0 = params.mainloop.scale_d0; + float scale_d1 = params.mainloop.scale_d1; + while (work_tile_info.is_valid()) + { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Allocate the accumulators for the (M,N) blk_shape + Tensor accumulators0 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + Tensor accumulators1 = partition_fragment_C(tiled_mma, take<0, 2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + // Order two Math WG's MMA one after the other, helps hide Epilogue + math_wg_order_barrier.wait(); + + collective_mainloop.mma(mainloop_pipeline, mainloop_pipe_consumer_state, accumulators0, accumulators1, + k_tile_count, warp_group_thread_idx, shared_storage.tensors.mainloop, params.mainloop); + + // Cue for next Math WG's MMA to start + math_wg_order_barrier.arrive(); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail(mainloop_pipeline, mainloop_pipe_consumer_state, k_tile_count); + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups); + + Activation elt_op; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators0); i++) + { + accumulators0[i] = (accumulators0[i] * scale_d0) * elt_op(scale_d1 * accumulators1[i]); + } + + // Order two Math WG's Epilogue one after the other + math_wg_order_barrier.wait(); + + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] + = collective_epilogue.store(epi_load_pipeline, epi_load_pipe_consumer_state, epi_store_pipeline, + epi_store_pipe_producer_state, problem_shape_MNKL, blk_shape, blk_coord, accumulators0, + tiled_mma, warp_group_thread_idx, shared_storage.tensors.epilogue); + + // TMA store pipeline wait is only visible to TMA-issuing warp, so for multiple-consumer kernels + // we need to wait for all TMA stores to complete before issuing consumer order barrier arrives + // to ensure next math consumer doesn't overwrite smem of in-flight TMA stores of current consumer. 
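+ // Unlike the cooperative kernel above, which defers store_tail until after the scheduler
+ // loop, the ping-pong kernel drains the TMA store pipeline here, once per tile, before
+ // signalling the other math warp group through the order barrier and handing the epilogue
+ // shared memory over to it.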
+ auto [epi_load_pipe_consumer_state_next_, epi_store_pipe_producer_state_next_] + = collective_epilogue.store_tail(epi_load_pipeline, epi_load_pipe_consumer_state_next, + epi_store_pipeline, epi_store_pipe_producer_state_next); + + // Update starting load/store pipeline states for the next tile + // state has already been incremented by 1 tile in collective calls, advance once again for ping pong + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next_; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next_; + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + + // Cue for next Math WG's Epilogue to start + math_wg_order_barrier.arrive(); + + // Get next work tile + scheduler.advance_to_next_work(NumMmaWarpGroups); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + } // Consumer Warp Groups End +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h new file mode 100644 index 00000000000..5e3531f0938 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/splitk_gemm_grouped.h @@ -0,0 +1,494 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h +*/ + +#pragma once + +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace kernel +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SplitkGemmGrouped +{ +public: + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + static bool const kTransposed = Transposed; + + // Optional transpose + using MapArguments = kernel::detail::MapArguments; + + // Public-facing type definitions related to operand element type, layout, and complex conjugate + // operation. Must interact with the 'kTransposed' notion. + using ElementA = typename MapArguments::ElementA; + using LayoutA = typename MapArguments::LayoutA; + using ElementB = typename MapArguments::ElementB; + using LayoutB = typename MapArguments::LayoutB; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename MapArguments::LayoutC; + + using ElementFinalOutput = typename MapArguments::ElementA; + + static ComplexTransform const kTransformA = MapArguments::kTransformA; + static ComplexTransform const kTransformB = MapArguments::kTransformB; + + // Type definitions about the mainloop. 
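+ // These aliases re-export the threadblock-scoped Mma's shapes, operator class and stage
+ // count so the problem visitor below and any device-level wrapper can be configured
+ // against a single source of truth.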
+ using Operator = typename Mma::Operator; + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = MapArguments::kAlignmentA; + static int const kAlignmentB = MapArguments::kAlignmentB; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using ProblemVisitor + = GemmGroupedProblemVisitor; + + // + // Structures + // + + /// Argument structure + struct Arguments + { + + // + // Data members + // + + GemmCoord* problem_sizes; + int problem_count; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; + + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; + + // Only used by device-level operator + GemmCoord* host_problem_sizes; + + // splitK + int split_k_slices; + int64_t* splitk_buffer_offsets; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() + : problem_count(0) + , threadblock_count(0) + , ptr_A(nullptr) + , ptr_B(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , lda(nullptr) + , ldb(nullptr) + , ldc(nullptr) + , ldd(nullptr) + , host_problem_sizes(nullptr) + , split_k_slices(1) + , splitk_buffer_offsets(nullptr) + { + } + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments(GemmCoord* problem_sizes, int problem_count, int threadblock_count, + typename EpilogueOutputOp::Params output_op, ElementA** ptr_A, ElementB** ptr_B, ElementFinalOutput** ptr_C, + ElementFinalOutput** ptr_D, typename LayoutA::Stride::LongIndex* lda, + typename LayoutB::Stride::LongIndex* ldb, typename LayoutC::Stride::LongIndex* ldc, + typename LayoutC::Stride::LongIndex* ldd, GemmCoord* host_problem_sizes, int split_k_slices, + int64_t* splitk_buffer_offsets) + : problem_sizes(problem_sizes) + , problem_count(problem_count) + , threadblock_count(threadblock_count) + , output_op(output_op) + , ptr_A(ptr_A) + , ptr_B(ptr_B) + , ptr_C(ptr_C) + , ptr_D(ptr_D) + , lda(lda) + , ldb(ldb) + , ldc(ldc) + , ldd(ldd) + , host_problem_sizes(host_problem_sizes) + , split_k_slices(split_k_slices) + , splitk_buffer_offsets(splitk_buffer_offsets) + { + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params + { + + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + typename EpilogueOutputOp::Params output_op; + + ElementA** ptr_A; + ElementB** ptr_B; + ElementFinalOutput** ptr_C; + ElementFinalOutput** ptr_D; + ElementC* ptr_C_split; + ElementC* ptr_D_split; + + typename LayoutA::Stride::LongIndex* lda; + typename LayoutB::Stride::LongIndex* ldb; + typename LayoutC::Stride::LongIndex* ldc; + typename LayoutC::Stride::LongIndex* ldd; + + // + // Methods + // + + // splitk + GemmCoord grid_tiled_shape; + int swizzle_log_tile; + int gemm_k_size; + GemmCoord* host_problem_sizes; + int split_k_slices; + int64_t* splitk_buffer_offsets; + + CUTLASS_HOST_DEVICE + Params() + : ptr_A(nullptr) + 
, ptr_B(nullptr) + , ptr_C(nullptr) + , ptr_D(nullptr) + , ptr_C_split(nullptr) + , ptr_D_split(nullptr) + , lda(nullptr) + , ldb(nullptr) + , ldc(nullptr) + , ldd(nullptr) + , swizzle_log_tile(0) + , gemm_k_size(0) + , host_problem_sizes(nullptr) + , split_k_slices(1) + , splitk_buffer_offsets(nullptr) + { + } + + CUTLASS_HOST_DEVICE + Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + : problem_visitor(args.problem_sizes, args.problem_count, workspace, tile_count) + , host_problem_sizes(args.host_problem_sizes) + , threadblock_count(args.threadblock_count) + , output_op(args.output_op) + , ptr_A(args.ptr_A) + , ptr_B(args.ptr_B) + , ptr_C(args.ptr_C) + , ptr_D(args.ptr_D) + , ptr_C_split((ElementC*) workspace) + , ptr_D_split((ElementC*) workspace) + , lda(args.lda) + , ldb(args.ldb) + , ldc(args.ldc) + , ldd(args.ldd) + , split_k_slices(args.split_k_slices) + , splitk_buffer_offsets(args.splitk_buffer_offsets) + { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + grid_tiled_shape = threadblock_swizzle.get_tiled_shape(args.host_problem_sizes[0], + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.split_k_slices); + swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); + + // only support same k + int full_gemm_k_iterations = args.host_problem_sizes[0].k() / Mma::Shape::kK; + int gemm_k_iterations = full_gemm_k_iterations / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + } + + CUTLASS_HOST_DEVICE + void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) + { + + problem_visitor = + typename ProblemVisitor::Params(args.problem_sizes, args.problem_count, workspace, tile_count); + threadblock_count = args.threadblock_count; + output_op = args.output_op; + ptr_A = args.ptr_A; + ptr_B = args.ptr_B; + ptr_C = args.ptr_C; + ptr_D = args.ptr_D; + ptr_C_split = workspace; + ptr_D_split = workspace; + + lda = args.lda; + ldb = args.ldb; + ldc = args.ldc; + ldd = args.ldd; + } + }; + + /// Shared memory storage structure + struct SharedStorage + { + union + { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + } kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + +public: + // + // Methods + // + + CUTLASS_DEVICE + SplitkGemmGrouped() {} + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) + { + return Status::kSuccess; + } + + static Status can_implement(Arguments const& args) + { + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const& params, SharedStorage& shared_storage) + { + + // + // These types shadow the type-level definitions and support the ability to implement + // a 'transposed' GEMM that computes the transposed problems. + // + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + // + // Problem visitor. 
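+ // Each threadblock iterates persistently: problem_visitor.next_tile() keeps handing out
+ // (problem index, tile index) pairs until every tile of every grouped problem is covered.
+ // The k-component of the swizzled tile offset selects this block's split-K slice; its
+ // partial result is written into the split workspace (ptr_C_split / ptr_D_split), offset
+ // per problem and per slice, and a separate pass is presumably expected to reduce the
+ // slices into ptr_D.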
+ // + ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) + { + + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + // Load element pointers. Exchange pointers and strides if working on the transpose + ElementA* ptr_A + = reinterpret_cast((kTransposed ? params.ptr_B[problem_idx] : params.ptr_A[problem_idx])); + typename LayoutA::LongIndex ldm_A = (kTransposed ? params.ldb[problem_idx] : params.lda[problem_idx]); + + ElementB* ptr_B + = reinterpret_cast((kTransposed ? params.ptr_A[problem_idx] : params.ptr_B[problem_idx])); + typename LayoutB::LongIndex ldm_B = (kTransposed ? params.lda[problem_idx] : params.ldb[problem_idx]); + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + cutlass::gemm::GemmCoord threadblock_offset(int(threadblock_idx / grid_shape.n()) * Mma::Shape::kM, + int(threadblock_idx % grid_shape.n()) * Mma::Shape::kN, 0); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_offset.m(), + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size, threadblock_offset.n()}; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k; + if (threadblock_tile_offset.k() + 1 == params.grid_tiled_shape.k()) + { + problem_size_k = problem_size.k(); + } + else + { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size_k}, thread_idx, tb_offset_A); + + typename Mma::IteratorB iterator_B( + LayoutB(ldm_B), ptr_B, {problem_size_k, problem_size.n()}, thread_idx, tb_offset_B); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = canonical_warp_idx_sync(); + + int lane_idx = threadIdx.x % 32; + + // + // Matrix multiply phase + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); + + // Wait for all threads to finish their epilogue phases from the previous tile. 
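+ // The mainloop and epilogue share a single shared-memory union (SharedStorage::kernel),
+ // so this barrier keeps the new tile's mainloop from clobbering smem that the previous
+ // tile's epilogue may still be reading.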
+ __syncthreads(); + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + ElementC* ptr_C = params.ptr_C_split; + ElementC* ptr_D = params.ptr_D_split; + + LayoutC layout_C(params.ldc[problem_idx]); + LayoutC layout_D(params.ldd[problem_idx]); + + typename Epilogue::OutputTileIterator::Params params_C(layout_C); + typename Epilogue::OutputTileIterator::Params params_D(layout_D); + + // assume identity swizzle + MatrixCoord threadblock_offset_C(threadblock_offset.m(), threadblock_offset.n()); + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C( + params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset_C); + + iterator_C.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset_C); + iterator_D.add_pointer_offset(problem_size.m() * problem_size.n() * threadblock_tile_offset.k() + + gridDim.z * params.splitk_buffer_offsets[problem_idx]); + + Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // Next tile + problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h new file mode 100644 index 00000000000..ed5e3e4daf8 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma.h @@ -0,0 +1,125 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ +//////////////////////////////////////////////////////////////////////////////// + +// We need to distinguish here, since we want volta support. It is too much effort +// to write shared memory iterators that are probably needed for volta to function +// properly. 
As a result, we allow converters both after the LDG (for volta) and after +// the LDS for Turing+. +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Warp level Mma + typename MmaOperator, + /// Math operation perform by warp level operator + typename MathOperator> +struct SetConverters +{ +}; + +// Dequantize after LDG, so set transforms accordingly +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters +{ + using TransformAfterLDG + = FastInterleavedAndBiasedNumericArrayConverter; + + using TransformAfterLDS = NumericArrayConverter; +}; + +// Dequantize after LDS, so set transforms accordingly + +template < + /// Iterator for B matrix in global memory + typename IteratorB, + /// Mma Policy + typename MmaOperator> +struct SetConverters +{ + using TransformAfterLDG = NumericArrayConverter; + + using TransformAfterLDS + = FastInterleavedAndBiasedNumericArrayConverter; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale_, + /// Layout for the scale operand + typename LayoutScale_, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// + typename Enable = void> +struct DqMma; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h new file mode 100644 index 00000000000..17c6346553c --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h @@ -0,0 +1,302 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/arch/mma.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage.h" +#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIteratorsMultistage; + +// Fine grained iterators +template +struct DefaultScaleIteratorsMultistage> +{ + using IteratorScale + = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +// Per column iterators +template +struct DefaultScaleIteratorsMultistage> +{ + // ThreadMap for scale iterator + static_assert((MmaShape::kN % Alignment) == 0, ""); + +private: + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + +public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, + Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale = IteratorScale; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && !layout::IsColumnMajorTileInterleave::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value + || platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename 
OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, ThreadMapB, + AccessTypeB>; + + using ScaleIterators = DefaultScaleIteratorsMultistage; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Stages in GEMM + int kStages, + /// Operator performed by GEMM + typename Operator_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +struct DqMma= 80 && layout::IsColumnMajorTileInterleave::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value + || platform::is_same::value, + "Element A must be fp16, fp8 or bf16"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(platform::is_same::value, + "Mma multistage must dequantize 
after ldsm"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = ((sizeof_bits::value * kAlignmentB) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + // Mma core does not depend on stages, so pass in at least 3 here to mma multistage pieces are created + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, ThreadMapA, + AccessTypeA>; + +private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape + = MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + +public: + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator; + + using ScaleIterators = DefaultScaleIteratorsMultistage; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converter = FastInterleavedAndBiasedNumericArrayConverter; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h new file mode 100644 index 00000000000..345cd2eec9a --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h @@ -0,0 +1,284 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/arch/mma.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/warp/default_mma_tensor_op.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" +#include "cutlass_extensions/tile_interleaved_layout.h" + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma.h" +#include "cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultScaleIteratorsPipelined; + +// Fine grained iterators +template +struct DefaultScaleIteratorsPipelined> +{ +private: + using SmemScaleType = half_t; + +public: + using IteratorScale + = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, Element, + Layout, 0, Alignment>; + + using SmemIteratorScale + = cutlass::transform::threadblock::FineGrainedScaleZeroIterator, + SmemScaleType, Layout, 0, Alignment>; +}; + +// Per column iterators +template +struct DefaultScaleIteratorsPipelined> +{ + static_assert((MmaShape::kN % Alignment) == 0, ""); + +private: + // ThreadMap for scale iterator + using IteratorScaleThreadMap = transform::PitchLinearStripminedThreadMap, + MmaShape::kN / Alignment, Alignment>; + using SmemScaleType = half_t; + +public: + // Define iterators over tiles from the scale operand + using IteratorScale = cutlass::transform::threadblock::PredicatedTileIterator, + Element, Layout, 0, IteratorScaleThreadMap, Alignment>; + + using SmemIteratorScale + = cutlass::transform::threadblock::PredicatedTileIterator, SmemScaleType, + Layout, 0, IteratorScaleThreadMap, Alignment>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + 
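+ // Note on the dequantization flow in this pipelined (two-stage) specialization: the B
+ // operand arrives packed as uint8/uint4 and is scaled with per-column scales loaded
+ // through IteratorScale. Depending on the detagged Operator, B is either converted to
+ // half before it is written to shared memory (see DqAfterLDG / MmaCoreElementB below)
+ // or kept quantized in shared memory and converted after the LDS.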
static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + static_assert(OperatorInfo::QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); + + static constexpr bool DqAfterLDG = platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementB, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB>; + + using ScaleIterators = DefaultScaleIteratorsPipelined; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +// Specialization to handle column major interleave B +template < + /// Type for element A + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Type for element B + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for the input scale + typename ElementScale, + /// Layout for the scale operand + typename LayoutScale, + /// Access granularity of Scales in unit of elements + int kAlignmentScale, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator_> +struct DqMma::value)>::type> +{ + + static_assert(platform::is_same::value || platform::is_same::value, + "Element A must be fp16 or bf16"); + + static_assert(platform::is_same::value || platform::is_same::value, + "Element B must be uint8 or uint4"); + + using OperatorInfo = arch::DetagOperator; + using Operator = typename OperatorInfo::Operator; + + static constexpr bool DqAfterLDG = platform::is_same::value; + using MmaCoreElementA = half_t; + using MmaCoreElementB = typename platform::conditional::type; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA>; + +private: + static constexpr int ColumnsInterleaved = LayoutB::kColumnsInterleaved; + static constexpr int RowsPerTile = 
LayoutB::kRowsPerTile; + static_assert(!(MmaCore::Shape::kN % ColumnsInterleaved), ""); + static_assert(RowsPerTile == MmaCore::Shape::kK, ""); + + using OriginalThreadMap = typename MmaCore::IteratorThreadMapB; + using OriginalWarpArrangement = typename OriginalThreadMap::Detail::WarpThreadArrangement; + static_assert(!(OriginalWarpArrangement::kStrided % ColumnsInterleaved), ""); + + using GmemIteratorShape + = MatrixShape; + using GmemThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, OriginalThreadMap::kThreads, + layout::PitchLinearShape, + MmaCore::kAccessSizeInBits / sizeof_bits::value>; + +public: + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator; + + // ThreadMap for scale iterator + static_assert((MmaCore::Shape::kN % kAlignmentScale) == 0, ""); + using IteratorScaleThreadMap + = transform::PitchLinearStripminedThreadMap, + MmaCore::Shape::kN / kAlignmentScale, kAlignmentScale>; + + using ScaleIterators = DefaultScaleIteratorsPipelined; + + // Define iterators over tiles from the scale operand + using IteratorScale = typename ScaleIterators::IteratorScale; + + using SmemIteratorScale = typename ScaleIterators::SmemIteratorScale; + + using Converters = SetConverters; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::DqMmaPipelined; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h new file mode 100644 index 00000000000..ad6c7496e14 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma.h @@ -0,0 +1,351 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" +#include "cutlass_extensions/gemm/threadblock/default_mma_bf16.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma pipelined (stage=2) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma pipelined (stage=2) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int8 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of 
elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +#ifdef ENABLE_FP8 +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp8 activation & int4 weight, mma multistage +/// (stage>=3) +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile 
size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +#endif + +// fp16 x fp16 specialization on Ampere to use mma multistage for 2 stage. Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma +{ + + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutA, 1, ThreadMapA, AccessTypeA, + GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, half_t, LayoutB, 0, ThreadMapB, AccessTypeB, + GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h new file mode 100644 index 00000000000..77af81005ab --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/default_mma_bf16.h @@ -0,0 +1,353 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_multistage.h" +#include "cutlass_extensions/gemm/threadblock/default_dq_mma_pipelined.h" + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & bf16 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma +{ + +private: + // Conversions only needed pre-ampere. This will trigger mma pipeline, so we convert before STS. + static constexpr bool arch_has_bf16_mma = ArchTag::kMinComputeCapability >= 80; + using MmaElementA = typename platform::conditional::type; + using MmaElementB = typename platform::conditional::type; + +public: + // Define the MmaCore components + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + using IteratorA = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, + typename MmaCore::IteratorThreadMapA, kAlignmentA, GatherA>; + + // Define iterators over tiles from the B operand + using IteratorB = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, + typename MmaCore::IteratorThreadMapB, kAlignmentB, GatherB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaPipelined; +}; + +// bf16 x bf16 specialization on Ampere to use mma multistage for 2 stage. 
Helps avoid reg spills on +// large tile when not enough shared mem is present to do 3+ stage +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB> +struct DefaultMma +{ + + // Define the MmaCore components + // 3 is used on purpose here to trigger components for mma multistage + using MmaCore = + typename cutlass::gemm::threadblock::DefaultMmaCore; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutA, 1, ThreadMapA, + AccessTypeA, GatherA>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, bfloat16_t, LayoutB, 0, ThreadMapB, + AccessTypeB, GatherB>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & int8 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), bf16 activation & 
int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output (OperatorClass TensorOp), fp16 activation & int4 weight +template < + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// + int kStages, + /// Shared memory 
clear option + SharedMemoryClearOption SharedMemoryClear> +struct DefaultMma +{ + +private: + static constexpr int kAlignmentScale = 128 / sizeof_bits::value; + + using Mma = DqMma; + +public: + // Define the MmaCore components + using MmaCore = typename Mma::MmaCore; + + // Define iterators over tiles from the A operand + using IteratorA = typename Mma::IteratorA; + + // Define iterators over tiles from the B operand + using IteratorB = typename Mma::IteratorB; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockMma = typename Mma::ThreadblockMma; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h new file mode 100644 index 00000000000..1fb7f7eb28f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_base.h @@ -0,0 +1,257 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
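+
+    The DqMmaBase defined in this header additionally reserves shared-memory storage for
+    the dequantization scales (and, when zero points are used, the zeros) of the B operand,
+    sized by ScalebiasStages and ScaleElementsPerStage below.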
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/weight_only_quant_op.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// +// SFINAE trick so I can keep the same loop code for Volta and dispatch to the +// correct warp level mma. On volta, all data is stored to shared memory as FP16. +template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, + typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C, + int const warp_tileB_k_offset) +{ + warp_mma(D, A, B, C); +} + +template +CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, + typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B, + typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset) +{ + warp_mma(D, A, B, C, warp_tileB_k_offset); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// The type of the scales + typename ElementScale_, + /// Number of stages, + int Stages, + /// The dequantizing op to be performed. + WeightOnlyQuantOp DequantOp, + /// Used for partial specialization, + typename Enable = bool> +class DqMmaBase +{ +public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + ///< Type of the scale to be loaded + using ElementScale = ElementScale_; + + static_assert(DequantOp != WeightOnlyQuantOp::UNDEFINED, ""); + + // Finegrained scales get streamed in via cp.async + static constexpr int ScalebiasStages = isFinegrained(DequantOp) ? Stages : 1; + // We always have scales. + static constexpr int ScaleElementsPerStage = Shape::kN; + // We sometimes have a bias + static constexpr int BiasElementsPerStage = hasZero(DequantOp) ? Shape::kN : 0; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
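+ /// kWarpGemmIterations below is derived from this shape, while kWarpGemmIterationsForB
+ /// divides it by kNumKIterationsPerWarpBLoad, since a single warp-level load of B spans
+ /// several MMA iterations along K.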
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM operations + static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + static constexpr int kNumKIterationsPerWarpBLoad + = Operator::IteratorB::InstructionShape::kRow / Operator::InstructionShape::kK; + + static_assert(!(kWarpGemmIterations % kNumKIterationsPerWarpBLoad), ""); + static constexpr int kWarpGemmIterationsForB = kWarpGemmIterations / kNumKIterationsPerWarpBLoad; + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage + { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA + = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB + = MatrixShape; + + /// Shape of the shared memory buffer for the scales for the B matrix. + using ShapeScale = MatrixShape; + /// Shape of the shared memory buffer for the biases of the B matrix. + using ShapeZero = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_scale; + + /// Buffer to hold scales for threadblock + AlignedBuffer operand_zero; + + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() + { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() + { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() + { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() + { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + +protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage& shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx) + , warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) + { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h new file mode 100644 index 00000000000..3c4036dd8cc --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage.h @@ -0,0 +1,110 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
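+/// This header only forward-declares DqMmaMultistage; the actual implementations are the
+/// per-column and fine-grained partial specializations pulled in by the #include
+/// directives at the end of this file.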
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = void> +class DqMmaMultistage; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h" +#include "cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h" diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h new file mode 100644 index 00000000000..f81961dee3c --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_finegrained.h @@ -0,0 +1,708 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
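+/// This partial specialization implements the fine-grained (group-wise) weight-only path:
+/// scales and optional zero points for each threadblock K tile are staged through shared
+/// memory with cp.async (see copy_scales_and_advance below), advancing according to the
+/// quantization group size, before the warp-level dequantization.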
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applied immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + static_assert(Base::SharedStorage::ShapeScale::kRow == Stages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == Shape::kN, ""); + + /// Internal structure exposed for introspection. 
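+    /// Detail counts the cp.async copies needed per stage for the A and B operands and
+    /// splits them into kAccessesPerGroupA/B chunks so that the global->shared copies can
+    /// be interleaved with the kWarpGemmIterations warp-level MMAs of a stage.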
+ struct Detail + { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA + = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB + = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + /// The group size for quantization + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), + shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in 
units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale, int stage = -1, int k_iter = -1) + { + static_assert(IteratorScale::Shape::kRow == 1, "Scale stride must be 1."); + + typename IteratorScale::AccessType* gmem_scale_ptr = iterator_scale.get_scale(); + typename IteratorScale::AccessType* gmem_zero_ptr = iterator_scale.get_zero(); + + typename IteratorScale::AccessType* smem_scale_ptr + = reinterpret_cast(this->smem_iterator_scale_.get_scale()); + typename IteratorScale::AccessType* smem_zero_ptr + = reinterpret_cast(this->smem_iterator_scale_.get_zero()); + + int const kSrcBytes = sizeof_bits::value * IteratorScale::kAlignment / 8; + + cutlass::arch::cp_async(smem_scale_ptr, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) + { + cutlass::arch::cp_async(smem_zero_ptr, gmem_zero_ptr, iterator_scale.valid()); + } + + if (iterator_scale.group_size_ == 64) + { + iterator_scale.add_tile_offset({1, 0}); + } + else if (iterator_scale.group_size_ == 128) + { + if constexpr (Shape::kK == 128) + { + iterator_scale.add_tile_offset({1, 0}); + } + else if constexpr (Shape::kK == 64) + { + if (iterator_scale.row_groupsize64_ & 0x1) + { + iterator_scale.add_tile_offset({1, 0}); + } + } + else + { + static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) + { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) + { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) + { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + 
cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) + { + + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) + { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + copy_scales_and_advance(iterator_scale, stage, gemm_k_iterations); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. 
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) + { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + typename Dequantizer::FragmentZero warp_frag_zeros; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) + { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
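The loads issued just below implement a register-level double buffer: while iteration warp_mma_k computes on warp_frag_A[warp_mma_k % 2], the next fragment is prefetched into warp_frag_A[(warp_mma_k + 1) % 2]. A minimal standalone sketch of that index pattern, with a hypothetical iteration count:

#include <cstdio>

int main()
{
    int const kWarpGemmIterations = 4; // hypothetical
    for (int warp_mma_k = 0; warp_mma_k < kWarpGemmIterations; ++warp_mma_k)
    {
        int const compute_buf = warp_mma_k % 2;    // fragment fed to the warp MMA this iteration
        int const load_buf = (warp_mma_k + 1) % 2; // fragment refilled from shared memory meanwhile
        std::printf("iter %d: compute on warp_frag[%d], load into warp_frag[%d]\n", warp_mma_k, compute_buf, load_buf);
    }
    return 0;
}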
+
+            this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
+            this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
+            ++this->warp_tile_iterator_A_;
+
+            int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
+            int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
+            if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
+            {
+                this->warp_tile_iterator_B_.set_kgroup_index(
+                    (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
+                this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
+                ++this->warp_tile_iterator_B_;
+            }
+
+            typename TransformBAfterLDS::result_type converted_frag_B
+                = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
+            warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zeros);
+
+            using FragmentOperandB = cutlass::Array;
+            constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
+            constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
+            static_assert(ConversionVectorWidth == FragmentOperandB::kElements);
+
+            using Converter
+                = cutlass::NumericArrayConverter;
+
+            FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
+            run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
+                warp_tileB_k_compute_offset);
+
+            // Issue global->shared copies for this stage
+            if (warp_mma_k < Base::kWarpGemmIterations - 1)
+            {
+                int group_start_iteration_A, group_start_iteration_B;
+
+                group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
+                group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
+
+                copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
+
+                // This is the first group of a given stage, so we issue the loads for the B scales immediately.
+                if (group_start_iteration_B == 0)
+                {
+                    copy_scales_and_advance(iterator_scale);
+                }
+            }
+
+            if (warp_mma_k + 2 == Base::kWarpGemmIterations)
+            {
+                int group_start_iteration_A, group_start_iteration_B;
+                group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
+                group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
+
+                copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
+
+                // Inserts a memory fence between stages of cp.async instructions.
+                cutlass::arch::cp_async_fence();
+
+                // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - + // #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } + else + { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + smem_read_stage_idx = 0; + } + else + { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_scale.clear_mask(gemm_k_iterations == 0); + } + } + + // Load the scale needed for the next tile iteration. + warp_dequantizer_.load(warp_frag_scales, warp_frag_zeros); + // Update internal pointer to set of scales in shared memory. + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h new file mode 100644 index 00000000000..83efdc5cb01 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_multistage_percol.h @@ -0,0 +1,647 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
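In contrast to the fine-grained variant earlier in this patch, the mainloop declared next assumes per-column quantization: B has a single scale per output column N and no zero point, so the whole scale fragment can be loaded once in the prologue and reused for every K iteration. A standalone host-side sketch of that scaling, with illustrative sizes:

#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    int const K = 128; // hypothetical reduction dimension
    int const N = 4;   // hypothetical number of B columns

    std::vector<int8_t> q(K * N, 2);    // quantized weights
    std::vector<float> scale(N, 0.25f); // a single scale per column
    std::vector<float> dq(K * N);

    for (int k = 0; k < K; ++k)
    {
        for (int n = 0; n < N; ++n)
        {
            dq[k * N + n] = scale[n] * float(q[k * N + n]); // same scale for every k
        }
    }

    std::printf("dq[k=0,n=0]=%f and dq[k=%d,n=0]=%f share scale[0]\n", dq[0], K - 1, dq[(K - 1) * N]);
    return 0;
}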
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear> +class DqMmaMultistage> + : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
+ struct Detail + { + + static_assert(Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA + = (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB + = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& shared_storage, + ///< Group size for quantization. 
Not used by this main loop since it assumes per-column + int const group_size, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, IteratorB& iterator_B, int group_start_A = 0, int group_start_B = 0) + { + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) + { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index(group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) + { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + auto gmem_ptr = iterator_B.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + cutlass::arch::cp_async_zfill(dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + else + { + cutlass::arch::cp_async(dst_ptr + v, gmem_ptr, 
iterator_B.valid()); + } + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over scale operand in global memory + IteratorScale iterator_scale, + ///< initial value of accumulator + FragmentC const& src_accum) + { + + // + // Prologue + // + + TransformBAfterLDS lds_converter; + + // NOTE - switch to ldg.sts + // Issue this first, so cp.async.commit_group will commit this load as well. + // Note: we do not commit here and this load will commit in the same group as + // the first load of A. + FragmentScale tb_frag_scales; + tb_frag_scales.clear(); + iterator_scale.load(tb_frag_scales); + this->smem_iterator_scale_.store(tb_frag_scales); + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) + { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) + { + int const kSrcBytes = sizeof_bits::value + * IteratorB::ThreadMap::kElementsPerAccess / IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. 
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) + { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) + { + + typename IteratorA::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) + { + + typename IteratorB::AccessType* dst_ptr + = reinterpret_cast(last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + typename Dequantizer::FragmentScale warp_frag_scales; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + warp_dequantizer_.load(warp_frag_scales); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) + { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
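Regarding the cp_async_wait with Base::kStages - 2 issued in the prologue above: the prologue commits kStages - 1 stages of cp.async copies, so waiting until at most kStages - 2 groups remain pending guarantees that at least one complete stage has landed in shared memory before the warp fragments are read. A standalone sketch of that bookkeeping, with a hypothetical stage count:

#include <cstdio>

int main()
{
    int const kStages = 4;                         // hypothetical pipeline depth
    int const committed_by_prologue = kStages - 1; // cp.async groups issued before the mainloop
    int const max_pending_after_wait = kStages - 2;
    int const stages_known_complete = committed_by_prologue - max_pending_after_wait;

    std::printf("prologue commits %d stage(s); waiting for <=%d pending leaves >=%d stage(s) resident in SMEM\n",
        committed_by_prologue, max_pending_after_wait, stages_known_complete);
    return 0;
}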
+
+            this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
+            this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
+            ++this->warp_tile_iterator_A_;
+
+            int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
+            int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
+            if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
+            {
+                this->warp_tile_iterator_B_.set_kgroup_index(
+                    (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB);
+                this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]);
+                ++this->warp_tile_iterator_B_;
+            }
+
+            typename TransformBAfterLDS::result_type converted_frag_B
+                = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]);
+            warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales);
+
+            using FragmentOperandB = cutlass::Array;
+            constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
+            constexpr int ConversionVectorWidth = TransformBAfterLDS::result_type::kElements;
+            static_assert(ConversionVectorWidth == FragmentOperandB::kElements);
+
+            using Converter
+                = cutlass::NumericArrayConverter;
+
+            FragmentOperandB converted_frag_B_operand = Converter::convert(converted_frag_B);
+            run_warp_mma(warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B_operand, accum,
+                warp_tileB_k_compute_offset);
+
+            // Issue global->shared copies for this stage
+            if (warp_mma_k < Base::kWarpGemmIterations - 1)
+            {
+                int group_start_iteration_A, group_start_iteration_B;
+
+                group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA;
+                group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB;
+
+                copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
+            }
+
+            if (warp_mma_k + 2 == Base::kWarpGemmIterations)
+            {
+                int group_start_iteration_A, group_start_iteration_B;
+                group_start_iteration_A = (warp_mma_k + 1) * Detail::kAccessesPerGroupA;
+                group_start_iteration_B = (warp_mma_k + 1) * Detail::kAccessesPerGroupB;
+
+                copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B);
+
+                // Inserts a memory fence between stages of cp.async instructions.
+                cutlass::arch::cp_async_fence();
+
+                // Wait until we have at least one committed global fetch stage. 
(#uncommitted = Base::kStages - 1 - + // #committed) + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } + else + { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + smem_read_stage_idx = 0; + } + else + { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + } + } + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) + { + // commit and drain all pending and predicated LDGSTS pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h new file mode 100644 index 00000000000..bd3e38971b0 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined.h @@ -0,0 +1,106 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type for the scales + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_, + /// Used for partial specialization + typename Enable = void> +class DqMmaPipelined; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h" +#include "cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h" diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h new file mode 100644 index 00000000000..50bdd0d85b0 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_finegrained.h @@ -0,0 +1,486 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
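The pipelined variant declared next is the pre-cp.async path: kStages is statically required to be 2, and the mainloop ping-pongs between the two shared-memory buffers by toggling smem_write_stage_idx with an XOR at the end of each K tile, the read side being implicitly the other buffer. A standalone sketch of that toggle, with a hypothetical tile count:

#include <cstdio>

int main()
{
    int smem_write_stage_idx = 1;              // stage 0 was filled by the prologue
    for (int k_tile = 0; k_tile < 6; ++k_tile) // hypothetical number of K tiles
    {
        int const smem_read_stage_idx = smem_write_stage_idx ^ 1; // always the other buffer
        std::printf("k_tile %d: write stage %d, read stage %d\n", k_tile, smem_write_stage_idx, smem_read_stage_idx);
        smem_write_stage_idx ^= 1; // toggle for the next tile
    }
    return 0;
}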
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_> +class DqMmaPipelined> + : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory; + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for DqMmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); + + static_assert(Base::SharedStorage::ShapeScale::kRow == Base::kStages, ""); + static_assert(Base::SharedStorage::ShapeScale::kColumn == 
Shape::kN, ""); + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + using WarpFragmentZero = typename Dequantizer::FragmentZero; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale and zero operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined(typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int const group_size, ///< The group size for quantization + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + {shared_storage.operand_zero.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), + shared_storage.operand_zero.data(), {Base::kStages, Shape::kN}, thread_idx, group_size) + { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_scales_and_advance(IteratorScale& iterator_scale) + { + using TransformScale = NumericArrayConverter; + + FragmentScale tb_frag_scales; + FragmentScale tb_frag_zeros; + tb_frag_scales.clear(); + tb_frag_zeros.clear(); + + TransformScale transformScale; + + using FragmentElement = typename FragmentScale::Element; + + auto gmem_scale_ptr = iterator_scale.get_scale(); + auto gmem_zero_ptr = iterator_scale.get_zero(); + + arch::global_load(tb_frag_scales, gmem_scale_ptr, iterator_scale.valid()); + + if (gmem_zero_ptr != nullptr) + { + arch::global_load( + tb_frag_zeros, 
gmem_zero_ptr, iterator_scale.valid()); + } + + typename TransformScale::result_type tb_frag_scales_fp16 = transformScale(tb_frag_scales); + typename TransformScale::result_type tb_frag_zeros_fp16; + if (gmem_zero_ptr != nullptr) + tb_frag_zeros_fp16 = transformScale(tb_frag_zeros); + + auto frag_scale_ptr_fp16 = reinterpret_cast(&tb_frag_scales_fp16); + auto frag_zero_ptr_fp16 = reinterpret_cast(&tb_frag_zeros_fp16); + auto smem_scale_ptr = this->smem_iterator_scale_.get_scale(); + auto smem_zero_ptr = this->smem_iterator_scale_.get_zero(); + + if (iterator_scale.valid()) + { + auto smem_offset = cast_smem_ptr_to_uint(smem_scale_ptr); + arch::shared_store(smem_offset, frag_scale_ptr_fp16); + + if (gmem_zero_ptr != nullptr) + { + smem_offset = cast_smem_ptr_to_uint(smem_zero_ptr); + arch::shared_store(smem_offset, frag_zero_ptr_fp16); + } + } + + if (iterator_scale.group_size_ == 64) + { + iterator_scale.add_tile_offset({1, 0}); + } + else if (iterator_scale.group_size_ == 128) + { + if constexpr (Shape::kK == 128) + { + iterator_scale.add_tile_offset({1, 0}); + } + else if constexpr (Shape::kK == 64) + { + if (iterator_scale.row_groupsize64_ & 0x1) + { + iterator_scale.add_tile_offset({1, 0}); + } + } + else + { + static_assert(Shape::kK == 0, "Unsupported k tile shape, can only be 64 or 128"); + } + } + + iterator_scale.row_groupsize64_++; + + this->smem_iterator_scale_.add_tile_offset({1, 0}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) + { ///< source accumulator tile + + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using TransformA + = NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want + // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
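As the comment above notes, TransformA exists only to turn bfloat16 activations into fp16 before they are written to shared memory on architectures that lack bf16 HMMA. A minimal sketch of that fragment conversion using cutlass::NumericArrayConverter, with an arbitrary fragment width of 8; it mirrors the Converter::convert pattern used later in this mainloop rather than the kernel's actual TransformA type:

#include "cutlass/array.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"

// Sketch only: convert a small bf16 fragment to fp16, element by element with
// round-to-nearest, as would happen before the fragment is stored to shared memory.
cutlass::Array<cutlass::half_t, 8> to_fp16_before_sts(cutlass::Array<cutlass::bfloat16_t, 8> const& frag)
{
    using Converter = cutlass::NumericArrayConverter<cutlass::half_t, cutlass::bfloat16_t, 8,
        cutlass::FloatRoundStyle::round_to_nearest>;
    return Converter::convert(frag);
}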
+ TransformA transformA; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + copy_scales_and_advance(iterator_scale); + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + WarpFragmentScale warp_frag_scales; + WarpFragmentZero warp_frag_zero; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + warp_dequantizer_.add_pointer_offset(Shape::kN); + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + iterator_scale.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) + { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) + { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_scale_.add_tile_offset({-Base::kStages, 0}); + } + else + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + warp_dequantizer_.add_pointer_offset(-Base::kStages * Shape::kN); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
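+ // Illustrative mapping (assuming Base::kNumKIterationsPerWarpBLoad == 2, i.e. the B load instruction spans
+ // twice the K extent of the compute MMA): warp_mma_k = 0,1,2,3 yields (compute_offset, load_offset) =
+ // (0,0), (1,0), (0,1), (1,1), so each B fragment fetched from shared memory is dequantized and reused for
+ // two consecutive compute iterations, and the next fragment's load is issued on the last compute iteration
+ // that consumes the current one.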
+ if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) + { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) + { + + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + copy_scales_and_advance(iterator_scale); + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + iterator_scale.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales, warp_frag_zero); + run_warp_mma( + warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); + } + + // Load the scales needed for the next tile iteration + warp_dequantizer_.load(warp_frag_scales, warp_frag_zero); + // Update internal pointer to the set of scales in shared memory + warp_dequantizer_.add_pointer_offset(Shape::kN); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h new file mode 100644 index 00000000000..316ea9f80a9 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/threadblock/dq_mma_pipelined_percol.h @@ -0,0 +1,399 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass_extensions/gemm/threadblock/dq_mma_base.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h" +#include "cutlass_extensions/interleaved_numeric_conversion.h" + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" +#include "cutlass_extensions/gemm_configs.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace threadblock +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Iterators over scales in global memory + typename IteratorScale_, + /// Iterators over scales in shared memory + typename SmemIteratorScale_, + /// Data type of accumulator matrix + typename ElementC_, + /// Layout of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Converter for B matrix applied immediately after the LDG (before STS) + typename TransformBAfterLDG_, + /// Converter for B matrix applited immediately after the LDS + typename TransformBAfterLDS_, + /// The quantization operator being used + WeightOnlyQuantOp QuantOp_> +class DqMmaPipelined> + : public DqMmaBase +{ +public: + ///< Base class + using Base = DqMmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + 
using IteratorScale = IteratorScale_; + using ElementScale = typename IteratorScale::Element; + using LayoutScale = typename IteratorScale::Layout; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + using SmemIteratorScale = SmemIteratorScale_; + + using TransformBAfterLDG = TransformBAfterLDG_; + using TransformBAfterLDS = TransformBAfterLDS_; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of operand Scale loaded from global memory + using FragmentScale = typename IteratorScale::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + using Dequantizer = warp::MmaTensorOpDequantizer; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // statically assert kStages for DqMmaPipelined is two (double-buffered pipeline) + static_assert((Base::kStages == 2), "DqMmaPipelined requires kStages set to value 2"); + +private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + Dequantizer warp_dequantizer_; + + using ElementA = typename IteratorA::Element; + using ElementB = typename IteratorB::Element; + using LayoutDetailsForB = kernel::LayoutDetailsB; + + static constexpr bool RequiresTileInterleave + = layout::IsColumnMajorTileInterleave::value; + static_assert(!RequiresTileInterleave || (RequiresTileInterleave && (Shape::kK == LayoutDetailsForB::ThreadblockK)), + "Layout K must match threadblockK"); + +protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of scale operand to shared memory + SmemIteratorScale smem_iterator_scale_; + +public: + /// Construct from tensor references + CUTLASS_DEVICE + DqMmaPipelined(typename Base::SharedStorage& + shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int const group_size, ///< Unused; present only so this per-column variant matches the fine-grained + ///< constructor interface. DqMmaPipelined is only enabled for sm<80, so carrying + ///< this extra argument has no effect on compilation for sm>=80.
+ int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx) + , warp_dequantizer_({shared_storage.operand_scale.data(), LayoutScale(Shape::kN)}, + (warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN)) / Base::WarpCount::kM, lane_idx) + , smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx) + , smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + , smem_iterator_scale_(LayoutScale(Shape::kN), shared_storage.operand_scale.data(), {1, Shape::kN}, thread_idx) + { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterationsForB * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()(int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + IteratorScale iterator_scale, ///< iterator over scale operand in global memory + FragmentC const& src_accum) + { ///< source accumulator tile + + // + // Prologue + // + TransformBAfterLDG ldg_converter; + TransformBAfterLDS lds_converter; + + using TransformA + = NumericArrayConverter; + + using TransformScale = NumericArrayConverter; + + // These transforms are mainly to handle when we have bfloat activations and weights in GMEM and want + // to issue HMMA on architectures older than Ampere. We will convert to FP16 before STS. 
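+ // In this per-column variant the scales are loaded, converted with TransformScale, and stored to shared
+ // memory exactly once, in the prologue below (one scale per output column covers the entire K dimension),
+ // so unlike the fine-grained variant there is no per-k-tile scale traffic in the mainloop.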
+ TransformA transformA; + TransformScale transformScale; + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + FragmentScale tb_frag_scales; + + using WarpFragmentScale = typename Dequantizer::FragmentScale; + WarpFragmentScale warp_frag_scales; + + tb_frag_A.clear(); + tb_frag_B.clear(); + tb_frag_scales.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + iterator_scale.load(tb_frag_scales); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transformA(tb_frag_A)); + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + this->smem_iterator_scale_.store(transformScale(tb_frag_scales)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + warp_dequantizer_.load(warp_frag_scales); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tighest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) + { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) + { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) + { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transformA(tb_frag_A)); + + this->smem_iterator_B_.store(ldg_converter(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) + { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else + { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterationsForB, 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; + int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; + // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. 
+ if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) + { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_tileB_k_load_offset + 1) % Base::kWarpGemmIterationsForB); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_tileB_k_load_offset + 1) % 2]); + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k == 0) + { + + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + typename TransformBAfterLDS::result_type converted_frag_B + = lds_converter(warp_frag_B[warp_tileB_k_load_offset % 2]); + warp_dequantizer_.dequantize(converted_frag_B, warp_frag_scales); + run_warp_mma( + warp_mma, accum, warp_frag_A[warp_mma_k % 2], converted_frag_B, accum, warp_tileB_k_compute_offset); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h new file mode 100644 index 00000000000..350b247de2e --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/default_mma_tensor_op.h @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Default warp-level GEMM operators selected by data type, size, and layouts of operands. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/default_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op.h" + +#include "cutlass_extensions/arch/mma.h" +#include "cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h" + +namespace cutlass +{ +namespace gemm +{ +namespace warp +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements, + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Number of partitions along K dimension + int PartitionsK, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor> +struct DefaultMmaTensorOp +{ + +private: + // Shape for computing the FP16s + using ComputeInstructionShape = InstructionShape_; + + // Chosen so we get K=16 for int8 and K=32 for int4. + static constexpr int LoadInstructionK = 128 / sizeof_bits::value; + + // Shape for loading the narrow data type from shared memory + using LoadInstructionShape = GemmShape; + +public: + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1>>; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::MmaTensorOpComputeBWithF16; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h new file mode 100644 index 00000000000..7c5088894b4 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_compute_B_with_f16.h @@ -0,0 +1,306 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations targeting + Tensor Cores. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" +#include "cutlass/arch/mma_sm89.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" + +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace warp +{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Instruction shape to override shared memory iterators with + typename SharedMemoryInstructionShape_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool> +class MmaTensorOpComputeBWithF16 +{ +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + static_assert((platform::is_same::value + && platform::is_same::value) + || (platform::is_same::value + && platform::is_same::value + && ArchTag::kMinComputeCapability >= 80) + || (platform::is_same::value + && platform::is_same::value + && ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports underlying HMMA/QMMA"); + + static_assert(platform::is_same::value + || (platform::is_same::value && ArchTag::kMinComputeCapability >= 80) + || (platform::is_same::value && ArchTag::kMinComputeCapability >= 89), + "MmaTensorOpCvtBToA only supports Fp16 A or Bf16 A on Ampere+, or FP8 on Ada"); + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Instruction shape to override shared memory iterators with + using SharedMemoryInstructionShape = SharedMemoryInstructionShape_; + + static_assert( + SharedMemoryInstructionShape::kM == InstructionShape::kM, "M dimension of compute instruction must match load"); + static_assert( + SharedMemoryInstructionShape::kN == InstructionShape::kN, "N dimension of compute instruction must match load"); + + static constexpr int kExpansionFactor = SharedMemoryInstructionShape::kK / InstructionShape::kK; + + static_assert(!(Shape::kK % SharedMemoryInstructionShape::kK), ""); + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + +public: + /// Iterates over the A operand in memory + using IteratorA + = MmaTensorOpMultiplicandTileIterator, Operand::kA, ElementA, LayoutA, + MatrixShape, Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator, Operand::kB, ElementB, + LayoutB, MatrixShape, Policy::OpDelta::kRow, + kThreadCount, kPartitionsK>; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed B tile + using 
TransformedFragmentB = Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + /// Number of mma operations performed + using MmaIterations = MatrixShape<(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN>; + +public: + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaTensorOpComputeBWithF16() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, + int const warp_tileB_k_offset) const + { + + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + static_assert( + TransformedFragmentB::kElements == MmaOperandB::kElements * kExpansionFactor * MmaIterations::kColumn, + "Each thread should have a pack of mma registers for each column iteration AND for the expanded K dim of " + "B"); + + D = C; + + MmaOperandA const* ptr_A = reinterpret_cast(&A); + MmaOperandB const* ptr_B = reinterpret_cast(&B); + MmaOperandC* ptr_D = reinterpret_cast(&D); + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) + { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) + { + + int m_serpentine = ((n % 2) ? (MmaIterations::kRow - 1 - m) : m); + + int n_offsetB = warp_tileB_k_offset + kExpansionFactor * n; + if (AccumulatorsInRowMajor) + { // matrix B is reordered + mma(ptr_D[n + m_serpentine * MmaIterations::kColumn], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } + else + { + mma(ptr_D[m_serpentine + n * MmaIterations::kRow], ptr_A[m_serpentine], ptr_B[n_offsetB], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) + { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) + { + + int n_serpentine = ((m % 2) ? 
(MmaIterations::kColumn - 1 - n) : n); + + int n_serpentine_offsetB = warp_tileB_k_offset + kExpansionFactor * n_serpentine; + if (AccumulatorsInRowMajor) + { // matrix B is reordered + mma(ptr_D[n_serpentine + m * MmaIterations::kColumn], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } + else + { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], ptr_A[m], ptr_B[n_serpentine_offsetB], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } +#else + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h new file mode 100644 index 00000000000..1d5cd5d8985 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/warp/mma_tensorop_dequantizer.h @@ -0,0 +1,463 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" + +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" + +#include "cutlass_extensions/weight_only_quant_op.h" +#include "tensorrt_llm/common/cudaBf16Wrapper.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace gemm +{ +namespace warp +{ + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Matrix multiply operator + typename MmaOperator_, + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + Operand Operand, + /// Data type of Scale elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Number of threads participating in one matrix operation + int Threads, + /// + WeightOnlyQuantOp QuantOp_, + /// + typename Enable = void> +class MmaTensorOpDequantizer; + +//////////////////////////////////////////////////////////////////////////////// +// Bfloat specialization for Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer= 80 + && platform::is_same::value>::type> +{ + +public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. 
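+ // For example, if IteratorB's load instruction spans 32 elements along K (as for int4 weights) while the
+ // fp16/bf16 compute MMA spans 16, kExpansionFactor == 2 and each loaded (and dequantized) B fragment feeds
+ // two compute MMAs along K; when the two K extents match, the factor is 1.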
+ static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = bfloat16_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) + { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) + { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) + { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) + { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. 
+ arch::device_breakpoint(); +#endif + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) + { + if constexpr (hasZero(QuantOp)) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize( + FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) + { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag); + __nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag); + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + static_assert(ExpandedMmaOperandB::kElements % 2 == 0, ""); + + __nv_bfloat162 scalex2 = __bfloat162bfloat162(scale_ptr[mma_n_iter]); + __nv_bfloat162 zerox2 = __bfloat162bfloat162(zero_ptr[mma_n_iter]); + __nv_bfloat162* operand_bf16x2_ptr = reinterpret_cast<__nv_bfloat162*>(&operand_frag_ptr[mma_n_iter]); + + if constexpr (hasZero(QuantOp)) + { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + { + operand_bf16x2_ptr[ii] = __hfma2(operand_bf16x2_ptr[ii], scalex2, zerox2); + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < ExpandedMmaOperandB::kElements / 2; ++ii) + { + operand_bf16x2_ptr[ii] = __hmul2(operand_bf16x2_ptr[ii], scalex2); + } + } + } +#else + // Slow path not implemented here on purpose. If we need to do HMMA on older arch, scale conversion should + // happen before scales are stored to shared memory and we should use the fp16 dequantizer. This will avoid + // numerous conversion instructions in GEMM main loop. + arch::device_breakpoint(); +#endif + } + + // Adds a pointer offset in units of elements. 
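+ // For example, the fine-grained DqMmaPipelined mainloop advances the dequantizer by Shape::kN after the
+ // scales for one k-stage have been loaded into registers, and rewinds by kStages * Shape::kN when its
+ // circular shared-memory buffer wraps around.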
+ CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) + { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + +private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +// Specialization for Turing & Ampere +template < + /// Underlying matrix multiply operator (concept: MmaTensorOp) + typename MmaOperator_, + /// Shape of the warp level matrix multiply (concept: GemmShape) + typename Shape_, + /// + WeightOnlyQuantOp QuantOp_> +class MmaTensorOpDequantizer= 75 + && platform::is_same::value>::type> +{ + +public: + /// Mma Operator + using MmaOperator = MmaOperator_; + + // The architecture specific mma ooperator being used + using ArchMmaOperator = typename MmaOperator::ArchMmaOperator; + + // Mma Instruction Shape + using InstructionShape = typename ArchMmaOperator::Shape; + + // This is the ratio of the load instruction vs the compute instruction. + static constexpr int kExpansionFactor = MmaOperator::IteratorB::InstructionShape::kRow / InstructionShape::kK; + + /// Type of the scales + using ElementScale = half_t; + + /// Fragment to hold B data before Mma + using FragmentDequantizedOperand = Array; + + // Fragment to hold scale data to apply to B before mma + // We need 1 fp16 per matrix iteration in the N dimension + static constexpr int kColsPerMmaPerThread = 1; + using FragmentScale = Array; + using FragmentZero = Array; + + /// Warp mma shape + using Shape = Shape_; + + /// Layout of the scales in shared memory + using Layout = layout::RowMajor; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx) + { + int const warp_offset = warp_idx_n * Shape::kN; + int const quad = lane_idx / 4; + int const thread_offset = warp_offset + quad; + pointer_scale_ = smem_scales.data() + thread_offset; + if constexpr (hasZero(QuantOp)) + { + pointer_zero_ = smem_zeros.data() + thread_offset; + } + } + + CUTLASS_DEVICE + MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx) + : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) + { + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + + CUTLASS_DEVICE + void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag) + { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB + = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + + CUTLASS_DEVICE + void load(FragmentScale& scale_frag, FragmentScale& zero_frag) + { + if constexpr (hasZero(QuantOp)) + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < 
MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + zero_frag[mma_n_iter] = pointer_zero_[mma_n_iter * InstructionShape::kN]; + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + scale_frag[mma_n_iter] = pointer_scale_[mma_n_iter * InstructionShape::kN]; + } + } + } + + CUTLASS_DEVICE + void dequantize( + FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag) + { + using _MmaOperandB = typename ArchMmaOperator::FragmentB; + using ExpandedMmaOperandB + = Array; + static_assert(ExpandedMmaOperandB::kElements * MmaOperator::MmaIterations::kColumn + == FragmentDequantizedOperand::kElements, + ""); + + multiplies mul_op; + ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast(&operand_frag); + + if constexpr (hasZero(QuantOp)) + { + plus plus_op; + + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + operand_frag_ptr[mma_n_iter] + = plus_op(mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]), zero_frag[mma_n_iter]); + } + } + else + { + CUTLASS_PRAGMA_UNROLL + for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) + { + operand_frag_ptr[mma_n_iter] = mul_op(operand_frag_ptr[mma_n_iter], scale_frag[mma_n_iter]); + } + } + } + + // Adds a pointer offset in units of elements. + CUTLASS_DEVICE + void add_pointer_offset(int64_t const& offset) + { + static_assert(sizeof(ElementScale) > 1, ""); + pointer_scale_ += offset; + pointer_zero_ += offset; + } + +private: + ElementScale const* pointer_scale_; + ElementScale const* pointer_zero_; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h new file mode 100644 index 00000000000..4acef2d180f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace tensorrt_llm +{ +namespace cutlass_extensions +{ +// Note: The shapes are in the format MxNxK. The K shape of the runtime config MUST match the K shape +// in the kernel layout details when doing weight only quantization. 
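+// For example, CtaShape128x128x64_WarpShape64x32x64 below describes a 128x128x64 (MxNxK) threadblock tile
+// computed by warps that each own a 64x32x64 sub-tile; per the note above, its K extent (64) must match
+// ThreadblockK of the kernel's layout details when running weight-only quantized GEMMs.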
+enum class CutlassTileConfig +{ + // Signals that we should run heuristics to choose a config + Undefined, + + // Signals that we should run heuristics to choose a config + ChooseWithHeuristic, + + // SIMT config + CtaShape128x128x8_WarpShape64x64x8, + + // TensorCore configs CTA_N = 128, CTA_K = 64 + // Warp configs for M=16 + CtaShape16x128x64_WarpShape16x32x64, + // Warp configs for M=32 + CtaShape32x128x64_WarpShape32x32x64, + + // Warp configs for M=64 + CtaShape64x128x64_WarpShape32x64x64, + CtaShape64x64x128_WarpShape32x64x64, + CtaShape64x128x64_WarpShape64x32x64, + + // Warp configs for M=128 + CtaShape128x64x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x32x64, + CtaShape128x128x64_WarpShape64x64x64, + CtaShape128x128x64_WarpShape128x32x64, + CtaShape128x256x64_WarpShape64x64x64, + + // Warp configs for M=256 + CtaShape256x128x64_WarpShape64x64x64, + + // TensorCore config CTA_N = 64, CTA_K = 128 + CtaShape128x64x128_WarpShape64x32x128, + + // TensorCore config CTA_N = 256, CTA_K = 64 + CtaShape16x256x64_WarpShape16x64x64, + + // TensorCore config CTA_N = 256, CTA_K = 128 + CtaShape16x256x128_WarpShape16x64x128 + +}; + +enum class SplitKStyle +{ + NO_SPLIT_K, + SPLIT_K_SERIAL, + STREAM_K, // Sm80+ + // SPLIT_K_PARALLEL // Not supported yet +}; + +enum class CutlassTileConfigSM90 +{ + // Signals that we should run heuristics to choose a config + Undefined, + + // Signals that we should run heuristics to choose a config + ChooseWithHeuristic, + + // CTA configs for M=64 + CtaShape64x16x128B, + CtaShape64x32x128B, + CtaShape64x64x128B, + CtaShape64x128x128B, + CtaShape64x256x128B, + + // CTA configs for M=128 + CtaShape128x16x128B, + CtaShape128x32x128B, + CtaShape128x64x128B, + CtaShape128x128x128B, + CtaShape128x256x128B, + + // CTA configs for M=256 + CtaShape256x128x128B, +}; + +enum class MainloopScheduleType +{ + AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this + // defaults to the "legacy" main loop schedule. +}; + +enum class EpilogueScheduleType +{ + AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For + // architectures older than Hopper, the epilogue is always performed by the same thread block as the main loop.
+}; + +enum class ClusterShape +{ + ClusterShape_1x1x1, + ClusterShape_2x1x1, + ClusterShape_1x2x1, + ClusterShape_2x2x1, + ClusterShape_1x8x1, + ClusterShape_8x1x1 +}; + +struct CutlassGemmConfig +{ + enum CandidateConfigTypeParam : int + { + NONE = 0, + WEIGHT_ONLY = 1u << 0, + SIMT_ONLY = 1u << 1, + INT8_ONLY = 1u << 2, + HOPPER = 1u << 3, + GROUPED_GEMM = 1u << 4, + FP8_ONLY = 1u << 5, + }; + + CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic; + SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K; + int split_k_factor = -1; + int stages = -1; + + // config options for sm90 + CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic; + MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO; + EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO; + ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1; + bool is_sm90 = false; + + CutlassGemmConfig() {} + + CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages) + : tile_config(tile_config) + , split_k_style(split_k_style) + , split_k_factor(split_k_factor) + , stages(stages) + , is_sm90(false) + { + } + + CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule, + EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape) + : tile_config_sm90(tile_config_sm90) + , mainloop_schedule(mainloop_schedule) + , epilogue_schedule(epilogue_schedule) + , cluster_shape(cluster_shape) + , is_sm90(true) + { + } + + std::string toString() const + { + std::stringstream tactic; + tactic << "Cutlass GEMM Tactic"; + if (tile_config_sm90 != tensorrt_llm::cutlass_extensions::CutlassTileConfigSM90::ChooseWithHeuristic) + { + assert(is_sm90 && "Invalid cutlass GEMM config"); + tactic << "\n\tstyle=TMA" + << "\n\ttile shape ID: " << (int) tile_config_sm90 << "\n\tcluster shape ID: " << (int) cluster_shape + << "\n\tmainloop sched: " << (int) mainloop_schedule << "\n\tepi sched: " << (int) epilogue_schedule; + } + else if (tile_config != tensorrt_llm::cutlass_extensions::CutlassTileConfig::ChooseWithHeuristic) + { + assert(!is_sm90 && "Invalid cutlass GEMM config"); + tactic << "\n\tstyle=compatible" + << "\n\ttile shape ID: " << (int) tile_config << "\n\tstages: " << (int) stages + << "\n\tsplit k: " << (int) split_k_factor; + } + else + { + tactic << "\n\tundefined"; + } + tactic << "\n"; + return tactic.str(); + } +}; + +inline std::ostream& operator<<(std::ostream& out, CutlassGemmConfig const& config) +{ + // clang-format off + if (config.is_sm90) + { + out << "tile_config_sm90_enum: " << int(config.tile_config_sm90) + << ", mainloop_schedule_enum: " << int(config.mainloop_schedule) + << ", epilogue_schedule_enum: " << int(config.epilogue_schedule) + << ", cluster_shape_enum: " << int(config.cluster_shape); + } + else + { + out << "tile_config_enum: " << int(config.tile_config) + << ", split_k_style_enum: " << int(config.split_k_style) + << ", split_k_factor: " << config.split_k_factor + << ", stages: " << config.stages; + } + // clang-format on + return out; +} + +} // namespace cutlass_extensions +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h new file mode 100644 index 00000000000..44ba79680e6 --- /dev/null +++ 
b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@@ -0,0 +1,447 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*!
+    \file
+    \brief Boost-like numeric conversion operator for int8 and CUTLASS int4b_t interleaved in a register
+*/
+
+#pragma once
+
+#include "cutlass/arch/arch.h"
+#include "cutlass/array.h"
+#include "cutlass/half.h"
+#include "cutlass/numeric_types.h"
+
+namespace cutlass
+{
+
+// This converter is meant to be used with data interleaved in a 32-bit register where the even elements are in the low
+// bits and the odd elements are in the high bits of the register. In addition, it assumes elements were originally
+// signed and had a bias of 2**(b-1) added (where b is the number of bits in the type) to make all numbers unsigned.
+// This converter will uninterleave the data and subtract the bias while converting to the result type.
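For orientation, the specializations that follow can be summarized by a small host-side reference. This is an illustrative sketch only, not part of the vendored header: the helper name reference_convert_u8x4 is ours, and the byte order [e0, e2, e1, e3] is inferred from the comment above together with the prmt selectors used in the uint8_t-to-half specialization below.

#include <array>
#include <cstdint>

// What the 4-element uint8_t path computes, expressed on the host: un-interleave
// the packed bytes (even logical elements sit in the low half of the word, odd
// elements in the high half) and remove the +128 bias that was added offline.
std::array<float, 4> reference_convert_u8x4(uint32_t packed)
{
    auto byte_at = [&](int k) { return static_cast<int>((packed >> (8 * k)) & 0xFFu); };
    return {static_cast<float>(byte_at(0) - 128),  // logical element 0
            static_cast<float>(byte_at(2) - 128),  // logical element 1
            static_cast<float>(byte_at(1) - 128),  // logical element 2
            static_cast<float>(byte_at(3) - 128)}; // logical element 3
}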
+template +struct FastInterleavedAndBiasedNumericArrayConverter +{ +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; + + uint32_t* h = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[0]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" : "=r"(h[1]) : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + // Lastly, we subtract 1152 from our constructed number using fp16 math to get our signed integer as fp16. + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[1]) : "r"(h[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* bf16_result_ptr = reinterpret_cast(&result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + // Construct FP32s, bfloat does not have enough mantissa for IADD trick + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + // Subtract out fp32_base + 128 to make the unsigned integer signed. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 4; ++ii) + { + fp32_intermediates[ii] -= 8388736.f; + } + + // Truncate the fp32 representation and pack up as bfloat16s. 
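A brief aside before the packing loop below, as a sanity check on the constants used in this bfloat16 path (an illustrative sketch only; the helper name debias_u8_via_fp32_base is ours): 0x4B000000 is 2^23 as an fp32, so each __byte_perm above produced 2^23 + u for a biased byte u, and subtracting 8388736.f, which equals 2^23 + 128, leaves the signed value u - 128. The loop that follows then keeps only the upper 16 bits of each fp32, i.e. its bfloat16 truncation, packing two results per 32-bit register via the __byte_perm selector 0x7632.

#include <cstdint>
#include <cstring>

// Host-side check of the debias trick used above (sketch, not part of the vendored file).
inline float debias_u8_via_fp32_base(uint8_t u)
{
    uint32_t bits = 0x4B000000u | u;   // biased byte placed in the low mantissa byte
    float f;
    std::memcpy(&f, &bits, sizeof(f)); // f == 8388608.0f + u, exactly
    return f - 8388736.f;              // == static_cast<float>(u) - 128.0f
}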
+        CUTLASS_PRAGMA_UNROLL
+        for (int ii = 0; ii < 2; ++ii)
+        {
+            bf16_result_ptr[ii]
+                = __byte_perm(fp32_intermediates_casted[2 * ii + 0], fp32_intermediates_casted[2 * ii + 1], 0x7632);
+        }
+#else
+        // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use
+        // HMMA on older hardware, they should convert directly to FP16 using FP16 converters.
+        result.clear(); // Suppress compiler warning
+        arch::device_breakpoint();
+#endif
+        return result;
+    }
+
+    CUTLASS_DEVICE
+    result_type operator()(source_type const& s)
+    {
+        return convert(s);
+    }
+};
+
+template
+struct FastInterleavedAndBiasedNumericArrayConverter
+{
+    static constexpr int VEC_WIDTH = 4;
+    static_assert(!(N % VEC_WIDTH), "N must be multiple of 4.");
+
+    using result_type = Array;
+    using source_type = Array;
+
+    CUTLASS_DEVICE
+    static result_type convert(source_type const& source)
+    {
+        using scalar_result_type = typename result_type::Element;
+        using scalar_source_type = typename source_type::Element;
+        FastInterleavedAndBiasedNumericArrayConverter
+            convert_vector_;
+
+        result_type result;
+        using vec_result = Array;
+        using vec_source = Array;
+
+        vec_result* result_ptr = reinterpret_cast(&result);
+        vec_source const* source_ptr = reinterpret_cast(&source);
+
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < N / VEC_WIDTH; ++i)
+        {
+            result_ptr[i] = convert_vector_(source_ptr[i]);
+        }
+
+        return result;
+    }
+
+    CUTLASS_DEVICE
+    result_type operator()(source_type const& s)
+    {
+        return convert(s);
+    }
+};
+
+template <>
+struct FastInterleavedAndBiasedNumericArrayConverter
+{
+    using result_type = Array;
+    using source_type = Array;
+
+    CUTLASS_DEVICE
+    static result_type convert(source_type const& source)
+    {
+        result_type result;
+
+        uint32_t* h = reinterpret_cast(&result);
+        uint32_t const i4s = reinterpret_cast(source);
+
+        // First, we extract the i4s and construct an intermediate fp16 number.
+        static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
+        static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
+        static constexpr uint32_t TOP_MASK = 0x00f000f0;
+        static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
+
+        // Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
+        // format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
+        // In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
+        // elt_67 to fp16 without having to shift them to the bottom bits beforehand.
+
+        // Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
+        // immediately before required.
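As a reading aid for the shift and the lop3/sub/fma sequence that follow, here is a host-side mirror of the per-nibble arithmetic (a sketch only; the helper name int4_pair_to_float is ours). The immLut value above evaluates to 0xEA, which for lop3.b32 encodes the boolean function (a & b) | c, i.e. (register & mask) | magic, so each extracted half holds either 1024 + v for a bottom nibble or 1024 + 16 * v for a top nibble; the constants used below (1032, 1/16, -72) reduce both cases to v - 8, removing the +8 bias of the 4-bit encoding.

#include <cstdint>

// Arithmetic mirror of one bottom/top nibble pair (illustration only).
inline void int4_pair_to_float(uint32_t reg, float& e_bottom, float& e_top)
{
    uint32_t const v_bottom = reg & 0xFu;     // bottom nibble, stored with a +8 bias
    uint32_t const v_top = (reg >> 4) & 0xFu; // adjacent top nibble, same bias
    e_bottom = static_cast<float>(1024 + v_bottom) - 1032.0f;              // == v_bottom - 8
    e_top = static_cast<float>(1024 + 16 * v_top) * (1.0f / 16.0f) - 72.0f; // == v_top - 8
}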
+ const uint32_t top_i4s = i4s >> 8; + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_23 (i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[1]) + : "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[2]) + : "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + // Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[3]) + : "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut)); + + // I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the + // half2 ctor. In this case, I chose performance reliability over code readability. + + // This is the half2 {1032, 1032} represented as an integer. + static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408; + // This is the half2 {1 / 16, 1 / 16} represented as an integer. + static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00; + // This is the half2 {-72, -72} represented as an integer. + static constexpr uint32_t NEG_72 = 0xd480d480; + + // Finally, we construct the output numbers. + // Convert elt_01 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_23 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + // Convert elt_45 + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM)); + // Convert elt_67 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_72)); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template <> +struct FastInterleavedAndBiasedNumericArrayConverter +{ + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + result_type result; +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + uint32_t* h = reinterpret_cast(&result); + uint32_t const source_i4s = reinterpret_cast(source); + + // First, we extract the i4s and construct an intermediate fp16 number. 
+ static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t I4s_TO_BF16s_MAGIC_NUM = 0x43004300; + + // We don't have enough mantissa to remove as much shift overhead as FP16, so we must loop. + // No shift needed for first item. + uint32_t i4s = source_i4s; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[0]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + CUTLASS_PRAGMA_UNROLL + for (int ii = 1; ii < result_type::kElements / 2; ++ii) + { + i4s >>= sizeof_bits::value; + // (i4s & 0x000f000f) | 0x43004300 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[ii]) + : "r"(i4s), "n"(MASK), "n"(I4s_TO_BF16s_MAGIC_NUM), "n"(immLut)); + } + + // This is the BF16 {-136, -136} represented as an integer. + static constexpr uint32_t BF16_BIAS = 0xC308C308; + static constexpr uint32_t BF16_ONE = 0x3F803F80; + + // Finally, we construct the output numbers. + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < result_type::kElements / 2; ++ii) + { + // Since this section is for Ampere+, we use bf16 fma to do the bias subtraction + asm("fma.rn.bf16x2 %0, %1, %2, %3;\n" : "=r"(h[ii]) : "r"(h[ii]), "r"(BF16_ONE), "r"(BF16_BIAS)); + } +#else + // Disable this on architectures older than Ampere since they lack hardware for bf16 mma. If one wishes to use + // HMMA on older hardware, they should Convert directly to FP16 using FP16 converters. + arch::device_breakpoint(); + result.clear(); // Suppress compiler warning. +#endif + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +template +struct FastInterleavedAndBiasedNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 8; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 8."); + + using result_type = Array; + using source_type = Array; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) + { + using scalar_result_type = typename result_type::Element; + using scalar_source_type = typename source_type::Element; + FastInterleavedAndBiasedNumericArrayConverter + convert_vector_; + + result_type result; + using vec_result = Array; + using vec_source = Array; + + vec_result* result_ptr = reinterpret_cast(&result); + vec_source const* source_ptr = reinterpret_cast(&source); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / VEC_WIDTH; ++i) + { + result_ptr[i] = convert_vector_(source_ptr[i]); + } + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) + { + return convert(s); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h new file mode 100644 index 00000000000..5a0cd295708 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/tile_interleaved_layout.h @@ -0,0 +1,66 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines new layouts needed for MoE +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/pitch_linear_coord.h" + +namespace cutlass +{ +namespace layout +{ + +template +struct ColumnMajorTileInterleave +{ + static constexpr int kRowsPerTile = RowsPerTile; + static constexpr int kColumnsInterleaved = ColumnsInterleaved; +}; + +template +struct IsColumnMajorTileInterleave +{ + static constexpr bool value = false; +}; + +template +struct IsColumnMajorTileInterleave> +{ + static constexpr bool value = true; +}; + +} // namespace layout +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h new file mode 100644 index 00000000000..6095925e372 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/transform/threadblock/fine_grained_scale_zero_iterator.h @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates for visiting scales to be used when dequantizing the weights for weight-only GEMM + quantization. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass +{ +namespace transform +{ +namespace threadblock +{ + +//////////////////////////////////////////////////////////////////////////////// + +template +class FineGrainedScaleZeroIterator; + +template +class FineGrainedScaleZeroIterator +{ +public: + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = 0; + static int const kAlignment = Alignment_; + + static int const kAccessesPerVector = 1; + + /// Row index of scales corresponding to the groupsize of 64 + int row_groupsize64_; + int group_size_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using AccessType = AlignedArray; + + using Fragment = cutlass::Array; + + // For compatibility with existing iterator interface + struct Params + { + LongIndex stride_ = 0; + + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_ = 0; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_(layout.stride(0)) + { + inc_advance_ = Shape::kRow * stride_ * sizeof_bits::value / 8; + } + }; + +private: + /// 
Internal pointer type permits fast address arithmetic
+    using BytePointer = char*;
+
+private:
+    //
+    // Data members
+    //
+
+    /// Parameters object with precomputed internal state
+    Params const params_;
+
+    /// Internal pointer to first access of tile
+    BytePointer pointer_scale_;
+    BytePointer pointer_zero_;
+
+    bool is_valid_ = false;
+
+public:
+    /// Constructs a TileIterator from its precomputed state, threadblock offset,
+    /// and thread ID
+    CUTLASS_DEVICE
+    FineGrainedScaleZeroIterator(
+        ///< Precomputed parameters object
+        Params const& params,
+        ///< Pointer to start of scale tensor
+        Pointer pointer_scale,
+        ///< Pointer to start of zero tensor
+        Pointer pointer_zero,
+        ///< Extent of the scale and bias
+        TensorCoord extent,
+        ///< ID of each participating thread
+        int thread_id,
+        ///< Initial offset of threadblock
+        TensorCoord const& threadblock_offset,
+        ///< Group size
+        int group_size)
+        : params_(params)
+        , pointer_scale_(reinterpret_cast(const_cast(pointer_scale)))
+        , pointer_zero_(reinterpret_cast(const_cast(pointer_zero)))
+    {
+        row_groupsize64_ = threadblock_offset.row();
+        group_size_ = group_size;
+
+        const LongIndex tb_row_byte_offset
+            = threadblock_offset.row() / (group_size / 64) * params_.stride_ * sizeof_bits::value / 8;
+        const LongIndex tb_col_byte_offset = threadblock_offset.column() * sizeof_bits::value / 8;
+        pointer_scale_ += (tb_row_byte_offset + tb_col_byte_offset);
+
+        if (pointer_zero_ != nullptr)
+        {
+            pointer_zero_ += (tb_row_byte_offset + tb_col_byte_offset);
+        }
+
+        static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment;
+
+        int const thread_row = thread_id / THREADS_PER_ROW;
+        int const thread_col = thread_id % THREADS_PER_ROW;
+
+        const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits::value / 8;
+        const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits::value / 8;
+        pointer_scale_ += (thread_row_byte_offset + thread_col_byte_offset);
+        if (pointer_zero_ != nullptr)
+        {
+            pointer_zero_ += (thread_row_byte_offset + thread_col_byte_offset);
+        }
+
+        // For the rows, we must check that we are within the extent AND the tile to avoid extra reads on
+        // a given iteration. The same threads will be responsible for issuing reads since the number of scales
+        // read in a given iteration is a constant. Therefore, we should never have to update is_valid_
+        // outside of the constructor.
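A small worked example of the thread mapping may help here; the Shape and alignment values are hypothetical, chosen only for illustration. With Shape::kColumn == 64 and kAlignment == 8 there are 8 threads per row, so thread 19 maps to thread_row 2 and thread_col 3 and is responsible for the 8 scales in columns 24..31 of its row; the bounds check that follows simply verifies that this (row, column) position lies inside both the tile and the tensor extent.

// How a thread id maps to the (row, first column) it reads (sketch only).
inline void map_scale_thread(int thread_id, int tile_columns, int alignment, int& thread_row, int& first_column)
{
    int const threads_per_row = tile_columns / alignment;     // e.g. 64 / 8 == 8
    thread_row = thread_id / threads_per_row;                 // e.g. 19 / 8 == 2
    first_column = (thread_id % threads_per_row) * alignment; // e.g. (19 % 8) * 8 == 24
}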
+ int const global_row = threadblock_offset.row() + thread_row; + int const global_col = threadblock_offset.column() + thread_col * kAlignment; + + bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; + bool const col_in_bounds = global_col < extent.column(); + + is_valid_ = row_in_bounds && col_in_bounds; + } + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE FineGrainedScaleZeroIterator(Params const& params, ///< Precomputed parameters object + Pointer pointer_scale, ///< Pointer to start of scale tensor + Pointer pointer_zero, ///< Pointer to start of zero tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + int group_size) + : FineGrainedScaleZeroIterator( + params, pointer_scale, pointer_zero, extent, thread_id, make_Coord(0, 0), group_size) + { + } + + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) + { + const LongIndex row_byte_offset = tile_offset.row() * params_.inc_advance_; + const LongIndex col_byte_offset = tile_offset.column() * Shape::kColumn * sizeof_bits::value / 8; + pointer_scale_ += row_byte_offset + col_byte_offset; + if (pointer_zero_ != nullptr) + { + pointer_zero_ += row_byte_offset + col_byte_offset; + } + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE void clear_mask(bool enable = true) + { + is_valid_ &= (!enable); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const + { + return is_valid_; + } + + /// Returns a scale pointer + CUTLASS_HOST_DEVICE + AccessType* get_scale() const + { + return reinterpret_cast(pointer_scale_); + } + + /// Returns a zero pointer + CUTLASS_HOST_DEVICE + AccessType* get_zero() const + { + return reinterpret_cast(pointer_zero_); + } +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp new file mode 100644 index 00000000000..b430380b014 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp @@ -0,0 +1,181 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/util/print.hpp" + +using namespace cute; + +/// Function object that applies an index to its argument +template +struct IndexedGather +{ + CUTE_HOST_DEVICE constexpr IndexedGather(Iter indices = {}) + : indices_(indices) + { + } + + template + CUTE_HOST_DEVICE constexpr auto operator()(I i) const + { + return indices_[i]; + } + + CUTE_HOST_DEVICE friend void print(IndexedGather const& s) + { + cute::print("Indexed{"); + print(s.indices_); + print("}"); + } + + Iter indices_; +}; + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride +{ + CUTE_HOST_DEVICE constexpr CustomStride(Func const& func, Stride const& stride) + : func_(func) + , stride_(stride) + { + } + + template + CUTE_HOST_DEVICE constexpr friend auto operator*(I i, CustomStride const& s) + { + return s.func_(i) * s.stride_; + } + + template + CUTE_HOST_DEVICE constexpr friend auto operator*(CustomStride const& s, I i) + { + return s.func_(i) * s.stride_; + } + + CUTE_HOST_DEVICE friend void print(CustomStride const& s) + { + cute::print("Custom{"); + print(s.func_); + cute::print(","); + print(s.stride_); + cute::print("}"); + } + + template + CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div) + { + return CustomStride(s.func_, safe_div(s.stride_, div)); + } + + // Circumvent the requirement on make_layout that shape and stride are integral + template + CUTE_HOST_DEVICE constexpr friend auto make_layout(Shape const& shape, CustomStride const& stride) + { + return Layout(shape, stride); + } + + Func func_; + Stride stride_; +}; + +template +CUTLASS_HOST_DEVICE auto make_custom_stride_layout(Stride const& stride, Func&& func) +{ + // Use a dummy shape and replace the first non-unit and non-zero stride with a custom gather stride + auto idx = find_if(stride, [](auto x) { return !is_constant<1, decltype(x)>{} && !is_constant<0, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + return make_layout( + repeat_like(stride, _1{}), replace(stride, CustomStride{static_cast(func), get(stride)})); +} + +/// Helper function to optionally create a gather tensor +template +CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, Stride const& stride, Func&& func) +{ + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); + return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); +} + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr auto upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) + { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s, d); }); + } + else if 
constexpr (is_scaled_basis::value) + { + if constexpr (Stride::mode() == I) + { + return make_layout(shape_div(shape, Int{}), shape_div(stride, Int{})); + } + else + { + return make_layout(shape, stride); + } + } + else + { + return upcast(shape, stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr auto upcast( + ComposedLayout, Offset, Layout> const& layout) +{ + // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), [](auto x) { return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +} // namespace cute diff --git a/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h new file mode 100644 index 00000000000..64774428e9f --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/weight_only_quant_op.h @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines iterators used by warp-level matrix multiply operations targeting Tensor Cores. 
+*/ + +#pragma once + +namespace cutlass +{ + +enum class WeightOnlyQuantOp +{ + UNDEFINED, + PER_COLUMN_SCALE_ONLY, + FINEGRAINED_SCALE_ONLY, + FINEGRAINED_SCALE_AND_ZEROS +}; + +constexpr bool isFinegrained(WeightOnlyQuantOp op) +{ + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS || op == WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY; +} + +constexpr bool hasZero(WeightOnlyQuantOp op) +{ + return op == WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS; +} + +} // namespace cutlass diff --git a/sgl-kernel/THIRDPARTYNOTICES.txt b/sgl-kernel/THIRDPARTYNOTICES.txt index c930aa5dd3d..fd717589792 100644 --- a/sgl-kernel/THIRDPARTYNOTICES.txt +++ b/sgl-kernel/THIRDPARTYNOTICES.txt @@ -223,3 +223,208 @@ BSD 3-Clause "New" License 3rdparty/cutlass include/flashinfer/attention/hopper/block_sparse_gather.cuh + +Notice for NVIDIA/TensorRT-LLM +------------------------------- + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 90c3cbc1d3c..50299140312 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -39,6 +39,8 @@ def _get_version(): cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) flashinfer = root / "3rdparty" / "flashinfer" turbomind = root / "3rdparty" / "turbomind" +tensorrt_llm_parent = root / "3rdparty" +tensorrt_llm = root / "3rdparty" / "tensorrt_llm" include_dirs = [ cutlass.resolve() / "include", cutlass.resolve() / "tools" / "util" / "include", @@ -51,6 +53,8 @@ def _get_version(): "cublasLt", turbomind.resolve(), turbomind.resolve() / "src", + tensorrt_llm_parent.resolve(), + tensorrt_llm.resolve() / "cutlass_extensions" / "include", ] nvcc_flags = [