From a105313511e7c0090dba3307c992c5c15916a4b4 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Thu, 30 Jan 2025 09:32:55 -0800 Subject: [PATCH 1/4] upd --- sgl-kernel/setup.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 50299140312..e5d097b14e7 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -94,6 +94,13 @@ def _get_version(): "3rdparty/flashinfer/csrc/sampling.cu", "3rdparty/flashinfer/csrc/renorm.cu", "3rdparty/flashinfer/csrc/rope.cu", + "3rdparty/tensorrt_llm/common/assert.cpp", + "3rdparty/tensorrt_llm/common/cublasMMWrapper.cpp", + "3rdparty/tensorrt_llm/common/cudaDriverWrapper.cpp", + "3rdparty/tensorrt_llm/common/logger.cpp", + "3rdparty/tensorrt_llm/common/stringUtils.cpp", + "3rdparty/tensorrt_llm/common/tllmException.cpp", + "3rdparty/tensorrt_llm/common/cudaFp8Utils.cu", ] enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1" From bc9b56dc187ad5f3ba32c4383dd2e7d8546c595f Mon Sep 17 00:00:00 2001 From: zhyncs Date: Thu, 30 Jan 2025 10:08:30 -0800 Subject: [PATCH 2/4] upd --- .../cutlass_kernels/cutlass_heuristic.cpp | 428 ++++++++++ .../cutlass_kernels/cutlass_heuristic.h | 58 ++ .../cutlass_kernels/cutlass_preprocessors.cpp | 803 ++++++++++++++++++ .../cutlass_kernels/cutlass_preprocessors.h | 76 ++ .../cutlass_kernels/cutlass_type_conversion.h | 112 +++ sgl-kernel/setup.py | 11 + 6 files changed, 1488 insertions(+) create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h create mode 100644 sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp new file mode 100644 index 00000000000..347d1caac82 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp @@ -0,0 +1,428 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h" +#include "tensorrt_llm/common/cudaBf16Wrapper.h" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ + +#include "cutlass/gemm/gemm.h" +#include "cutlass/numeric_types.h" +#include "tensorrt_llm/common/assert.h" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC + +#include +#include +#include + +using namespace tensorrt_llm::cutlass_extensions; + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +struct TileShape +{ + int m; + int n; +}; + +TileShape get_cta_shape_for_config(CutlassTileConfig tile_config) +{ + switch (tile_config) + { + case CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64: return TileShape{16, 128}; + case CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64: return TileShape{16, 256}; + case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: return TileShape{32, 128}; + case CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64: return TileShape{64, 64}; + case CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64: + case CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64: return TileShape{64, 128}; + case CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64: return TileShape{128, 64}; + case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64: + case CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64: return TileShape{128, 128}; + case CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64: return TileShape{128, 256}; + case CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64: return TileShape{256, 128}; + case CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128: return TileShape{16, 256}; + default: TLLM_THROW("[get_grid_shape_for_config] Invalid config"); + } +} + +bool is_valid_split_k_factor(int64_t const m, int64_t const n, int64_t const k, TileShape const tile_shape, + int const split_k_factor, size_t const workspace_bytes, bool const is_weight_only) +{ + + // All tile sizes have a k_tile of 64. + static constexpr int k_tile = 64; + + // For weight-only quant, we need k and k_elements_per_split to be a multiple of cta_k + if (is_weight_only) + { + if ((k % k_tile) != 0) + { + return false; + } + + if ((k % split_k_factor) != 0) + { + return false; + } + + int const k_elements_per_split = k / split_k_factor; + if ((k_elements_per_split % k_tile) != 0) + { + return false; + } + } + + // Check that the workspace has sufficient space for this split-k factor + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + int const required_ws_bytes = split_k_factor == 1 ? 
0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; + + if (required_ws_bytes > workspace_bytes) + { + return false; + } + + return true; +} + +std::vector get_candidate_tiles( + int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) +{ + enum class CutlassGemmType : char + { + Default, + WeightOnly, + Simt, + Int8, + Fp8 + }; + + CutlassGemmType gemm_type = CutlassGemmType::Default; + if (config_type_param & CutlassGemmConfig::SIMT_ONLY) + { + gemm_type = CutlassGemmType::Simt; + } + else if (config_type_param & CutlassGemmConfig::WEIGHT_ONLY) + { + gemm_type = CutlassGemmType::WeightOnly; + } + else if (config_type_param & CutlassGemmConfig::INT8_ONLY) + { + gemm_type = CutlassGemmType::Int8; + } + else if (config_type_param & CutlassGemmConfig::FP8_ONLY) + { + gemm_type = CutlassGemmType::Fp8; + } + + std::vector base_configs{ + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64}; + if (sm >= 75) + { + base_configs.push_back(CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64); + } + + switch (gemm_type) + { + case CutlassGemmType::Simt: return {CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8}; + case CutlassGemmType::WeightOnly: + if (sm >= 75) + { + return {CutlassTileConfig::CtaShape16x128x64_WarpShape16x32x64, + CutlassTileConfig::CtaShape16x256x64_WarpShape16x64x64, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64}; + } + else + { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64}; + } + case CutlassGemmType::Int8: + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + case CutlassGemmType::Fp8: + if (config_type_param & CutlassGemmConfig::GROUPED_GEMM) + { + if (sm == 89) + { + return {CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128, + CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64}; + } + else + { + // no valid ampere style fp8 configs for sm90 + return {}; + } + } + else + { + if (sm == 89) + { + return {CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape32x64x64, + CutlassTileConfig::CtaShape64x64x128_WarpShape32x64x64, + CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x64x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x32x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64, + CutlassTileConfig::CtaShape128x256x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape256x128x64_WarpShape64x64x64, + CutlassTileConfig::CtaShape128x64x128_WarpShape64x32x128, + CutlassTileConfig::CtaShape16x256x128_WarpShape16x64x128}; + } + else + { + return {}; + } + } + default: return base_configs; + } +} + +std::vector 
get_candidate_tiles_sm90( + int const sm, CutlassGemmConfig::CandidateConfigTypeParam const config) +{ +#ifdef FAST_BUILD + // Fast build disables all configs except this one for SM90 + return {CutlassTileConfigSM90::CtaShape128x128x128B}; +#else + if (config & CutlassGemmConfig::GROUPED_GEMM) + { + return {CutlassTileConfigSM90::CtaShape128x16x128B, CutlassTileConfigSM90::CtaShape128x32x128B, + CutlassTileConfigSM90::CtaShape128x64x128B, CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B}; + } + else + { + return {CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B, + CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B, + CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B}; + } +#endif +} + +// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve +// compilation speed. +bool supports_mcast_along_m(CutlassTileConfigSM90 const tile) +{ +#ifdef FAST_BUILD + return false; +#else + std::set valid_tiles{CutlassTileConfigSM90::CtaShape128x16x128B, + CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B, + CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B, + CutlassTileConfigSM90::CtaShape256x128x128B}; + return valid_tiles.count(tile) == 1; +#endif +} + +// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= 128. This is purely to improve +// compilation speed. 
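+// Both predicates feed get_candidate_configs() below: supporting multi-cast
+// along M adds the 2x1x1 cluster shape as a candidate, supporting it along N
+// adds 1x2x1, and supporting both also adds 2x2x1 on top of the default 1x1x1.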
+bool supports_mcast_along_n(CutlassTileConfigSM90 const tile) +{ +#ifdef FAST_BUILD + return false; +#else + std::set valid_tiles{CutlassTileConfigSM90::CtaShape64x128x128B, + CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x128x128B, + CutlassTileConfigSM90::CtaShape128x256x128B, CutlassTileConfigSM90::CtaShape256x128x128B}; + return valid_tiles.count(tile) == 1; +#endif +} + +std::vector get_candidate_configs( + int sm, int const max_split_k, CutlassGemmConfig::CandidateConfigTypeParam const config_type_param) +{ + if (sm == 90 && (config_type_param & CutlassGemmConfig::HOPPER)) + { + std::vector tiles = get_candidate_tiles_sm90(sm, config_type_param); + + std::vector candidate_configs; + for (auto const& tile_config : tiles) + { + CutlassGemmConfig config( + tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1); + candidate_configs.push_back(config); + + bool const has_m_mcast = supports_mcast_along_m(tile_config); + bool const has_n_mcast = supports_mcast_along_n(tile_config); + if (has_m_mcast) + { + CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_2x1x1); + candidate_configs.push_back(config); + } + + if (has_n_mcast) + { + CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_1x2x1); + candidate_configs.push_back(config); + } + + if (has_m_mcast && has_n_mcast) + { + CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, + ClusterShape::ClusterShape_2x2x1); + candidate_configs.push_back(config); + } + } + return candidate_configs; + } + std::vector tiles = get_candidate_tiles(sm, config_type_param); + + std::vector candidate_configs; + bool const int8_configs_only = config_type_param & CutlassGemmConfig::INT8_ONLY; + int const min_stages = int8_configs_only ? 3 : 2; + int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2); + for (auto const& tile_config : tiles) + { + for (int stages = min_stages; stages <= max_stages; ++stages) + { + CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages); + candidate_configs.push_back(config); + if (sm >= 75) + { + for (int split_k_factor = 2; split_k_factor <= max_split_k; ++split_k_factor) + { + auto config = CutlassGemmConfig{tile_config, SplitKStyle::SPLIT_K_SERIAL, split_k_factor, stages}; + candidate_configs.push_back(config); + } + } + } + } + + return candidate_configs; +} + +CutlassGemmConfig estimate_best_config_from_occupancies(std::vector const& candidate_configs, + std::vector const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const num_experts, + int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only) +{ + + if (occupancies.size() != candidate_configs.size()) + { + TLLM_THROW( + "[estimate_best_config_from_occupancies] occpancies and " + "candidate configs vectors must have equal length."); + } + + CutlassGemmConfig best_config; + // Score will be [0, 1]. The objective is to minimize this score. + // It represents the fraction of SM resources unused in the last wave. + float config_score = 1.0f; + int config_waves = INT_MAX; + int current_m_tile = 0; + + int const max_split_k = n >= multi_processor_count * 256 ? 
1 : split_k_limit; + for (int ii = 0; ii < candidate_configs.size(); ++ii) + { + CutlassGemmConfig candidate_config = candidate_configs[ii]; + TileShape tile_shape = get_cta_shape_for_config(candidate_config.tile_config); + int occupancy = occupancies[ii]; + + if (occupancy == 0) + { + continue; + } + + // Keep small tile sizes when possible. + if (best_config.tile_config != CutlassTileConfig::ChooseWithHeuristic && m < current_m_tile + && current_m_tile < tile_shape.m) + { + continue; + } + + int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; + int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; + + for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) + { + if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) + { + int const ctas_per_wave = occupancy * multi_processor_count; + int const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; + + int const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; + float const num_waves_fractional = ctas_for_problem / float(ctas_per_wave); + float const current_score = float(num_waves_total) - num_waves_fractional; + + float const score_slack = 0.1f; + if (current_score < config_score + || ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) + { + config_score = current_score; + config_waves = num_waves_total; + SplitKStyle split_style + = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig( + candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages); + current_m_tile = tile_shape.m; + } + else if (current_score == config_score + && (best_config.stages < candidate_config.stages || split_k_factor < best_config.split_k_factor + || current_m_tile < tile_shape.m)) + { + // Prefer deeper pipeline or smaller split-k + SplitKStyle split_style + = split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K; + best_config = CutlassGemmConfig( + candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages); + current_m_tile = tile_shape.m; + config_waves = num_waves_total; + } + } + } + } + + if (best_config.tile_config == CutlassTileConfig::ChooseWithHeuristic) + { + TLLM_THROW("Heurisitc failed to find a valid config."); + } + + return best_config; +} + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h new file mode 100644 index 00000000000..f7b20ea41e8 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "cute/tensor.hpp" +#include "cutlass_extensions/gemm_configs.h" +#include "tensorrt_llm/common/cudaUtils.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +template +struct should_filter_sm90_gemm_problem_shape +{ +#ifdef FAST_BUILD + constexpr static int TILE_K = 128 * 8 / cutlass::sizeof_bits::value; + using SupportedCtaShape = cute::Shape>; + using SupportedCgaShape = cute::Shape; + + constexpr static bool value + = !cute::is_same_v || !cute::is_same_v; +#else + constexpr static bool value = false; +#endif +}; +template +constexpr static bool should_filter_sm90_gemm_problem_shape_v + = should_filter_sm90_gemm_problem_shape::value; + +std::vector get_candidate_configs( + int sm, int const max_split_k, tensorrt_llm::cutlass_extensions::CutlassGemmConfig::CandidateConfigTypeParam const); + +tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies( + std::vector const& candidate_configs, + std::vector const& occupancies, int64_t const m, int64_t const n, int64_t const k, int64_t const num_experts, + int const split_k_limit, size_t const workspace_bytes, int const multi_processor_count, int const is_weight_only); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp new file mode 100644 index 00000000000..f5f8488d441 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp @@ -0,0 +1,803 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h" +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaBf16Wrapper.h" +#include "tensorrt_llm/common/stringUtils.h" + +#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h" + +using namespace tensorrt_llm::common; + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +struct LayoutDetails +{ + enum class Layout + { + UNKNOWN, + ROW_MAJOR, + COLUMN_MAJOR + }; + + Layout layoutB = Layout::UNKNOWN; + int rows_per_column_tile = 1; + int columns_interleaved = 1; + + bool uses_imma_ldsm = false; +}; + +template +struct getLayoutDetails +{ +}; + +template <> +struct getLayoutDetails +{ + LayoutDetails operator()() + { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::ROW_MAJOR; + return layout_details; + } +}; + +template <> +struct getLayoutDetails +{ + LayoutDetails operator()() + { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + return layout_details; + } +}; + +template +struct getLayoutDetails> +{ + LayoutDetails operator()() + { + LayoutDetails layout_details; + layout_details.layoutB = LayoutDetails::Layout::COLUMN_MAJOR; + layout_details.rows_per_column_tile = RowsPerTile; + layout_details.columns_interleaved = ColumnsInterleaved; + return layout_details; + } +}; + +template +LayoutDetails getLayoutDetailsForArchAndQuantType() +{ + + using CompileTraits = cutlass::gemm::kernel::LayoutDetailsB; + using LayoutB = typename CompileTraits::Layout; + using MmaOperator = typename CompileTraits::Operator; + LayoutDetails details = getLayoutDetails()(); + details.uses_imma_ldsm = std::is_same::value; + return details; +} + +template +LayoutDetails getLayoutDetailsForArch(QuantType quant_type) +{ + int const bits_per_weight_element = get_weight_quant_bits(quant_type); + LayoutDetails details; + switch (quant_type) + { + case QuantType::W8_A16: + details = getLayoutDetailsForArchAndQuantType(); + break; + case QuantType::W4_A16: + details = getLayoutDetailsForArchAndQuantType(); + break; + case QuantType::W4_AFP8: + details = getLayoutDetailsForArchAndQuantType(); + break; + default: TLLM_THROW("Unsupported quantization type"); + } + return details; +} + +LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch) +{ + if (arch >= 75 && arch < 80) + { + return getLayoutDetailsForArch(quant_type); + } + else if (arch >= 80 && arch < 90) + { + return getLayoutDetailsForArch(quant_type); + } + else if (arch == 90) + { + return getLayoutDetailsForArch(quant_type); + } + else + { + TLLM_CHECK_WITH_INFO(false, "Unsupported Arch"); + return LayoutDetails(); + } +} + +// Permutes the rows of B in a way that is compatible with Turing+ architectures. +// +// Throws an error for other architectures. +// The data is permuted such that: +// For W8_A16, each group of 16 rows is permuted using the map below: +// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 +// For W4_A16, each group of 32 rows is permuted using the map below: +// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31 +// For W4_A8, see the map in the code. The idea is similar to above. +// The goal of this permutation is to ensure data ends up in the correct threads after +// we execute LDSM. It counteracts the effect of the data being of different widths. +// For more information about the expected layouts, see the MMA section in the PTX docs. 
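+// The maps below are consumed by permute_B_rows_for_mixed_gemm(), which fills
+// each output tile as
+//   permuted[base_row + i][col] = original[base_row + map[i]][col]
+// so with the W8_A16 map, rows 2 and 3 of every 16-row group are taken from
+// original rows 8 and 9, and so on.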
+std::vector get_permutation_map(QuantType quant_type) +{ + + if (quant_type == QuantType::W8_A16) + { + return {0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15}; + } + else if (quant_type == QuantType::W4_A16) + { + return {0, 1, 8, 9, 16, 17, 24, 25, 2, 3, 10, 11, 18, 19, 26, 27, 4, 5, 12, 13, 20, 21, 28, 29, 6, 7, 14, 15, + 22, 23, 30, 31}; + } + else if (quant_type == QuantType::W4_AFP8) + { + return {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23, 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, + 28, 29, 30, 31}; + } + else + { + TLLM_THROW("Invalid quantization type for LDSM permutation"); + } +} + +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type, int64_t const arch_version) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + // We only want to run this step for weight only quant. + std::vector row_permutation = get_permutation_map(quant_type); + + TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + int const BITS_PER_ELT = get_weight_quant_bits(quant_type); + int const K = 16 / BITS_PER_ELT; + int const ELTS_PER_BYTE = 8 / BITS_PER_ELT; + int const ELTS_PER_REG = 32 / BITS_PER_ELT; + + uint32_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = reinterpret_cast(permuted_quantized_tensor); + + int MMA_SHAPE_N = 8; + int B_ROWS_PER_MMA = 8 * K; + int const elts_in_int32 = 32 / BITS_PER_ELT; + + int const num_vec_cols = num_cols / elts_in_int32; + + TLLM_CHECK_WITH_INFO( + arch_version >= 75, "Unsupported Arch. Pre-volta not supported. Column interleave not needed on Volta."); + + TLLM_CHECK_WITH_INFO(num_rows % B_ROWS_PER_MMA == 0, + fmtstr("Invalid shape for quantized tensor. Number of rows of quantized matrix must be a multiple of %d", + B_ROWS_PER_MMA)); + TLLM_CHECK_WITH_INFO(num_cols % MMA_SHAPE_N == 0, + fmtstr("Invalid shape for quantized tensor. On turing/Ampere, the number of cols must be a multiple of %d.", + MMA_SHAPE_N)); + + TLLM_CHECK_WITH_INFO(size_t(B_ROWS_PER_MMA) == row_permutation.size(), "Unexpected number of LDSM rows permuted."); + + for (int expert = 0; expert < num_experts; ++expert) + { + const int64_t matrix_offset = expert * int64_t(num_rows) * int64_t(num_vec_cols); + for (int base_row = 0; base_row < num_rows; base_row += B_ROWS_PER_MMA) + { + for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row) + { + + for (int write_col = 0; write_col < num_vec_cols; ++write_col) + { + int const write_row = base_row + tile_row; + int const tile_read_row = row_permutation[tile_row]; + int const read_row = base_row + tile_read_row; + int const read_col = write_col; + + const int64_t read_offset = matrix_offset + int64_t(read_row) * num_vec_cols + read_col; + const int64_t write_offset = matrix_offset + int64_t(write_row) * num_vec_cols + write_col; + + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } + } +} + +// We need to use this transpose to correctly handle packed int4 and int8 data +// The reason this code is relatively complex is that the "trivial" loops took a substantial +// amount of time to transpose leading to long preprocessing times. This seemed to be a big +// issue for relatively large models. 
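+// The implementation below therefore streams the matrix through a cache tile of
+// M_TILE_L1 x M_TILE_L1 elements (M_TILE_L1 rows by N_TILE_L1 bytes), copies
+// rows in VECTOR_WIDTH-byte chunks so the compiler can emit vector loads and
+// stores, and transposes within the tile by swapping whole bytes for 8-bit
+// weights or individual nibbles for 4-bit weights.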
+template +void subbyte_transpose_impl( + int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, std::vector const& shape) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + constexpr int bits_per_elt = get_weight_quant_bits(quant_type); + + TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + const size_t col_bytes = num_cols * bits_per_elt / 8; + const size_t col_bytes_trans = num_rows * bits_per_elt / 8; + const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes; + + uint8_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint8_t* output_byte_ptr = reinterpret_cast(transposed_quantized_tensor); + + static constexpr int ELTS_PER_BYTE = 8 / bits_per_elt; + + static constexpr int M_TILE_L1 = 64; + static constexpr int N_TILE_L1 = M_TILE_L1 / ELTS_PER_BYTE; + uint8_t cache_buf[M_TILE_L1][N_TILE_L1]; + + static constexpr int VECTOR_WIDTH = std::min(32, N_TILE_L1); + + // We assume the dims are a multiple of vector width. Our kernels only handle dims which are multiples + // of 64 for weight-only quantization. As a result, this seemed like a reasonable tradeoff because it + // allows GCC to emit vector instructions. + TLLM_CHECK_WITH_INFO(!(col_bytes_trans % VECTOR_WIDTH) && !(col_bytes % VECTOR_WIDTH), + fmtstr("Number of bytes for rows and cols must be a multiple of %d. However, num_rows_bytes = %ld and " + "num_col_bytes = %ld.", + VECTOR_WIDTH, col_bytes_trans, col_bytes)); + + int const num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; + int const num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; + + for (size_t expert = 0; expert < num_experts; ++expert) + { + const size_t matrix_offset = expert * num_rows * col_bytes; + for (size_t row_tile_start = 0; row_tile_start < num_rows; row_tile_start += M_TILE_L1) + { + for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1) + { + + int const row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); + int const col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes); + + for (int ii = 0; ii < M_TILE_L1; ++ii) + { + int const row = row_tile_start + ii; + + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) + { + int const col = col_tile_start_byte + jj; + + const size_t logical_src_offset = matrix_offset + row * col_bytes + col; + + if (row < row_limit && col < col_limit) + { + for (int v = 0; v < VECTOR_WIDTH; ++v) + { + cache_buf[ii][jj + v] = input_byte_ptr[logical_src_offset + v]; + } + } + } + } + + if constexpr (bits_per_elt == 8) + { + for (int ii = 0; ii < M_TILE_L1; ++ii) + { + for (int jj = ii + 1; jj < N_TILE_L1; ++jj) + { + std::swap(cache_buf[ii][jj], cache_buf[jj][ii]); + } + } + } + else if constexpr (bits_per_elt == 4) + { + + for (int ii = 0; ii < M_TILE_L1; ++ii) + { + // Using M_TILE_L1 here is deliberate since we assume that the cache tile + // is square in the number of elements (not necessarily the number of bytes). 
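+// Each iteration swaps element (ii, jj) with element (jj, ii): both nibbles
+// are extracted, the destination slots are cleared with the shifted 0xF0
+// mask, and the opposite element's nibble is OR-ed back in.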
+ for (int jj = ii + 1; jj < M_TILE_L1; ++jj) + { + int const ii_byte = ii / ELTS_PER_BYTE; + int const ii_bit_offset = ii % ELTS_PER_BYTE; + + int const jj_byte = jj / ELTS_PER_BYTE; + int const jj_bit_offset = jj % ELTS_PER_BYTE; + + uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); + uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] &= (0xF0 >> (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] &= (0xF0 >> (4 * ii_bit_offset)); + + cache_buf[ii][jj_byte] |= (tgt_elt << (4 * jj_bit_offset)); + cache_buf[jj][ii_byte] |= (src_elt << (4 * ii_bit_offset)); + } + } + } + else + { + TLLM_CHECK_WITH_INFO(false, "Unsupported quantization type."); + } + + const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; + const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; + + int const row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols); + int const col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); + + for (int ii = 0; ii < M_TILE_L1; ++ii) + { + int const row = row_tile_start_trans + ii; + for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) + { + int const col = col_tile_start_byte_trans + jj; + + const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col; + + if (row < row_limit_trans && col < col_limit_trans) + { + for (int v = 0; v < VECTOR_WIDTH; ++v) + { + output_byte_ptr[logical_tgt_offset + v] = cache_buf[ii][jj + v]; + } + } + } + } + } + } + } +} + +void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + if (quant_type == QuantType::W8_A16) + { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } + else if (quant_type == QuantType::W4_A16) + { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } + else if (quant_type == QuantType::W4_AFP8) + { + subbyte_transpose_impl(transposed_quantized_tensor, quantized_tensor, shape); + } + else + { + TLLM_CHECK_WITH_INFO(false, "Invalid quant_type"); + } +} + +void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, const size_t num_elts) +{ + for (int ii = 0; ii < num_elts; ++ii) + { + int8_tensor[ii] = int8_t(int(int8_tensor[ii]) + 128); + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to match the int4 layout. This has no + // performance benefit and is purely so that int4 and int8 have the same layout. + // Pictorially, this does the following: + // bit 32 0 + // [elt_3 elt_2 elt_1 elt_0] (each elt occupies 8 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_3 elt_1 elt_2 elt_0] (each elt occupies 8 bits) + + TLLM_CHECK_WITH_INFO(num_elts % 4 == 0, "Dimensions of int8 tensor must be a multiple of 4 for register relayout"); + for (size_t base = 0; base < num_elts; base += 4) + { + std::swap(int8_tensor[base + 1], int8_tensor[base + 2]); + } +} + +void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts) +{ + int const num_bytes = num_elts / 2; + + // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little + // instructions as possible in the CUDA code. 
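+// Adding 8 maps the signed int4 range [-8, 7] onto the unsigned range [0, 15];
+// the checks below verify that every transformed nibble lands in [0, 15]
+// before the two elements are re-packed into a single byte.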
+ for (size_t ii = 0; ii < num_bytes; ++ii) + { + int8_t transformed_packed_int4s = 0; + int8_t transformed_first_elt + = (int8_t(packed_int4_tensor[ii] << 4) >> 4) + 8; // The double shift here is to ensure sign extension + int8_t transformed_second_elt = (packed_int4_tensor[ii] >> 4) + 8; + + TLLM_CHECK_WITH_INFO( + transformed_first_elt >= 0 && transformed_first_elt <= 15, "Illegal result for int4 transform (first elt)"); + TLLM_CHECK_WITH_INFO(transformed_second_elt >= 0 && transformed_second_elt <= 15, + "Illegal result for int4 transform (second elt)"); + + // We don't need to mask in these ops since everything should be in the range 0-15 + transformed_packed_int4s |= transformed_first_elt; + transformed_packed_int4s |= (transformed_second_elt << 4); + packed_int4_tensor[ii] = transformed_packed_int4s; + } + + // Step 2 will transform the layout of a 32-bit register in CUDA in order to minimize the number of shift & logical + // instructions That are needed to extract the int4s in the GEMM main loop. Pictorially, the loop below will do the + // following: Take as input a 32 bit register with layout: bit 32 0 + // [elt_7 elt_6 elt_5 elt_4 elt_3 elt_2 elt_1 elt_0] (each elt occupies 4 bits) + // + // And it will rearrange the output 32 bit register to be the following: + // bit 32 0 + // [elt_7 elt_5 elt_3 elt_1 elt_6 elt_4 elt_2 elt_0] (each elt occupies 4 bits) + + TLLM_CHECK_WITH_INFO(num_bytes % 4 == 0, "Dimensions of int4 tensor must be a multiple of 8 for register relayout"); + const size_t num_registers = num_bytes / 4; + + uint32_t* register_ptr = reinterpret_cast(packed_int4_tensor); + for (size_t ii = 0; ii < num_registers; ++ii) + { + const uint32_t current_register = register_ptr[ii]; + uint32_t transformed_register = 0; + + for (int dest_idx = 0; dest_idx < 8; ++dest_idx) + { + int const src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; + int const src_shift = 4 * src_idx; + int const dest_shift = 4 * dest_idx; + + const uint32_t src_bits = (current_register >> src_shift) & 0xF; + transformed_register |= (src_bits << dest_shift); + } + register_ptr[ii] = transformed_register; + } +} + +void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + if (quant_type == QuantType::W8_A16) + { + add_bias_and_interleave_int8s_inplace(tensor, num_elts); + } + else if (quant_type == QuantType::W4_A16 || quant_type == QuantType::W4_AFP8) + { + // W4_AFP8 uses the same preprocessor as W4_A16 because the FP8 data must + // be converted to FP16 before the scales can be applied using CUDA cores. + // As a result, we still want permute the data so that it is well aligned + // for conversion to FP16. + add_bias_and_interleave_int4s_inplace(tensor, num_elts); + } + else + { + TLLM_CHECK_WITH_INFO(false, "Invalid quantization type for interleaving."); + } +} + +void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type, LayoutDetails details) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? 
shape[1] : shape[2]; + + int const BITS_PER_ELT = get_weight_quant_bits(quant_type); + int const elts_in_int32 = 32 / BITS_PER_ELT; + + int const rows_per_tile = details.rows_per_column_tile; + + TLLM_CHECK_WITH_INFO(!(num_rows % elts_in_int32), + fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", elts_in_int32, num_rows)); + + uint32_t const* input_byte_ptr = reinterpret_cast(quantized_tensor); + uint32_t* output_byte_ptr = reinterpret_cast(interleaved_quantized_tensor); + + TLLM_CHECK_WITH_INFO(!(num_rows % rows_per_tile), + fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", rows_per_tile, num_rows)); + + int const num_vec_rows = num_rows / elts_in_int32; + int const vec_rows_per_tile = rows_per_tile / elts_in_int32; + int const interleave = details.columns_interleaved; + + for (int expert = 0; expert < num_experts; ++expert) + { + const int64_t matrix_offset = expert * int64_t(num_vec_rows) * int64_t(num_cols); + for (int read_col = 0; read_col < num_cols; ++read_col) + { + const int64_t write_col = read_col / interleave; + for (int base_vec_row = 0; base_vec_row < num_vec_rows; base_vec_row += vec_rows_per_tile) + { + for (int vec_read_row = base_vec_row; + vec_read_row < std::min(num_vec_rows, base_vec_row + vec_rows_per_tile); ++vec_read_row) + { + const int64_t vec_write_row = interleave * base_vec_row + + vec_rows_per_tile * (read_col % interleave) + vec_read_row % vec_rows_per_tile; + + const int64_t read_offset = matrix_offset + int64_t(read_col) * num_vec_rows + vec_read_row; + const int64_t write_offset + = matrix_offset + int64_t(write_col) * num_vec_rows * interleave + vec_write_row; + output_byte_ptr[write_offset] = input_byte_ptr[read_offset]; + } + } + } + } +} + +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight, + std::vector const& shape, QuantType quant_type, bool force_interleave) +{ + int arch = getSMVersion(); + if (force_interleave && arch == 90) + { + // Workaround for MOE which doesn't have specialised Hopper kernels yet + arch = 80; + } + LayoutDetails details = getLayoutDetailsForTransform(quant_type, arch); + + TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + + size_t num_elts = 1; + for (auto const& dim : shape) + { + num_elts *= dim; + } + + const size_t num_bytes = num_elts * get_weight_quant_bits(quant_type) / 8; + + std::vector src_buf(num_bytes); + std::vector dst_buf(num_bytes); + std::copy(row_major_quantized_weight, row_major_quantized_weight + num_bytes, src_buf.begin()); + + // Works on row major data, so issue this permutation first. + if (details.uses_imma_ldsm) + { + permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch); + src_buf.swap(dst_buf); + } + + if (details.layoutB == LayoutDetails::Layout::COLUMN_MAJOR) + { + subbyte_transpose(dst_buf.data(), src_buf.data(), shape, quant_type); + src_buf.swap(dst_buf); + } + + if (details.columns_interleaved > 1) + { + interleave_column_major_tensor(dst_buf.data(), src_buf.data(), shape, quant_type, details); + src_buf.swap(dst_buf); + } + + if (arch >= 70 && arch < 90) + { + add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type); + } + std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight); +} + +/* + Arguments: + input_weight_ptr - the weight tensor to be quantized. Must be 2-D or 3-D and of type FP16. 
+ + quant_type - the type of the output quantization weight. + + This function does symmetric quantization on 2-D or 3-D tensors. It uses the full int range and assumes the + zero-point is zero and will automatically construct the scales. + + It always quantizes the last axis of the tensor. For 3-D tensors, it operates in "batched" mode where the tensor is + viewed as a stack of matrices and a scale is produced for each column of every matrix. + +Outputs + processed_quantized_weight - quantized AND processed weight for GEMM. This MUST be used with the CUTLASS GEMM + unprocessed_quantized_weight - quantized but unprocessed weights. Useful for reference checking. + scale_ptr - scales for the quantized weight. + + Note that the returned quantized_weights will be preprocessed in a way to accelerate the mixed type GEMM. The data + layout may not make sense if printed. + + Shapes: + quant_type == int8: + If weight is a [m,n] matrix, quantized_weights will have shape [m,n] and scales of shape [n] + If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m,n] and scales of shape [b,n] + quant_type == int4: + If weight is a [m,n] matrix, quantized_weights will have shape [m, ceil(n/2)] and scales of shape [n] + If weight is a [b,m,n] tensor, unprocessed_quantized_weight will have shape [b,m, ceil(n/2)] and scales of shape + [b,n] + + The quantized_weight will be of type torch.int8 and have two int4 values packed in a single byte. This is the + reason for halving the shape. At the time of writing this code, there was not an elegant way to handle this kind + of batched quantization using torch's quantized tensors (to the best of the author's knowledge). Scale tensors + must have a dimension of 1, which breaks the semantics we need for batched weights. + */ + +template +void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector const& shape, QuantType quant_type, + bool force_interleave) +{ + + TLLM_CHECK_WITH_INFO(processed_quantized_weight, "Processed quantized tensor is NULL"); + TLLM_CHECK_WITH_INFO(scale_ptr, "Scale output pointer is NULL"); + TLLM_CHECK_WITH_INFO(input_weight_ptr, "Input weight pointer is NULL"); + + TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); + const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; + const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; + const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; + + int const bits_in_type = get_weight_quant_bits(quant_type); + int const bytes_per_out_col = num_cols * bits_in_type / 8; + + int const bits_per_weigtht_element = get_weight_quant_bits(quant_type); + + std::vector weight_buf; + if (unprocessed_quantized_weight == nullptr) + { + weight_buf.resize(num_experts * num_rows * num_cols); + unprocessed_quantized_weight = weight_buf.data(); + } + + int const input_mat_size = num_rows * num_cols; + int const quantized_mat_size = num_rows * bytes_per_out_col; + float const quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); + + std::vector per_col_max(num_cols); + + for (int expert = 0; expert < num_experts; ++expert) + { + WeightType const* current_weight = input_weight_ptr + expert * input_mat_size; + int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size; + + // First we find the per column max for this expert weight. 
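+// Symmetric per-column quantization: scale[j] = max_i |w[i][j]| / 2^(bits - 1),
+// and each weight is later stored as round(w[i][j] / scale[j]) clamped to the
+// signed range of the target type ([-128, 127] for int8, [-8, 7] for int4).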
+ for (int jj = 0; jj < num_cols; ++jj) + { + per_col_max[jj] = 0.f; + } + + for (int ii = 0; ii < num_rows; ++ii) + { + WeightType const* current_weight_row = current_weight + ii * num_cols; + for (int jj = 0; jj < num_cols; ++jj) + { + per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj]))); + } + } + + // Then, we construct the scales + ComputeType* current_scales = scale_ptr + expert * num_cols; + for (int jj = 0; jj < num_cols; ++jj) + { + per_col_max[jj] *= quant_range_scale; + current_scales[jj] = ComputeType(per_col_max[jj]); + } + + // Finally, construct the weights. + for (int ii = 0; ii < num_rows; ++ii) + { + int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col; + WeightType const* current_weight_row = current_weight + ii * num_cols; + for (int jj = 0; jj < bytes_per_out_col; ++jj) + { + + if (bits_per_weigtht_element == 8) + { + float const col_scale = per_col_max[jj]; + float const weight_elt = float(current_weight_row[jj]); + float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; + const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight))); + current_quantized_weight_row[jj] = clipped_weight; + } + else if (bits_per_weigtht_element == 4) + { + + // We will pack two int4 elements per iteration of the inner loop. + int8_t packed_int4s = 0; + for (int packed_idx = 0; packed_idx < 2; ++packed_idx) + { + int const input_idx = 2 * jj + packed_idx; + if (input_idx < num_cols) + { + float const col_scale = per_col_max[input_idx]; + float const weight_elt = float(current_weight_row[input_idx]); + float const scaled_weight = (col_scale != 0.0f) ? round(weight_elt / col_scale) : 0.0f; + int int_weight = int(scaled_weight); + const int8_t clipped_weight = std::max(-8, std::min(7, int_weight)); + + // Kill the sign extension bits (hence 0x0F mask) then shift to upper bits + // if packing the second int4 and or the bits into the final result. 
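+// Example: clipped_weight = -3 is 0xFD as an int8_t; masking with 0x0F keeps
+// the low nibble 0x0D, and packed_idx == 1 shifts it into the high nibble
+// (0xD0) of the packed byte.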
+ packed_int4s |= ((clipped_weight & 0x0F) << (4 * packed_idx)); + } + } + current_quantized_weight_row[jj] = packed_int4s; + } + else + { + TLLM_CHECK_WITH_INFO(false, "Unsupported quantization type"); + } + } + } + } + + preprocess_weights_for_mixed_gemm( + processed_quantized_weight, unprocessed_quantized_weight, shape, quant_type, force_interleave); +} + +template void symmetric_quantize( + int8_t*, int8_t*, half*, float const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize( + int8_t*, int8_t*, half*, half const*, std::vector const&, QuantType, bool); + +#ifdef ENABLE_BF16 +template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( + int8_t*, int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, float>( + int8_t*, int8_t*, __nv_bfloat16*, float const*, std::vector const&, QuantType, bool); +#endif + +template +void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr, + std::vector const& shape, QuantType quant_type, bool force_interleave) +{ + symmetric_quantize( + processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave); +} + +template void symmetric_quantize( + int8_t*, float*, float const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize( + int8_t*, half*, float const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize(int8_t*, half*, half const*, std::vector const&, QuantType, bool); + +#ifdef ENABLE_BF16 +template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( + int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, half>( + int8_t*, __nv_bfloat16*, half const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize( + int8_t*, half*, __nv_bfloat16 const*, std::vector const&, QuantType, bool); + +template void symmetric_quantize<__nv_bfloat16, float>( + int8_t*, __nv_bfloat16*, float const*, std::vector const&, QuantType, bool); +#endif + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h new file mode 100644 index 00000000000..b12fd737245 --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.h @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +#include "tensorrt_llm/common/cudaUtils.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +enum class QuantType +{ + W8_A16, + W4_A16, + W4_AFP8 +}; + +constexpr int get_weight_quant_bits(QuantType quant_type) +{ + switch (quant_type) + { + case QuantType::W8_A16: return 8; + case QuantType::W4_A16: return 4; + case QuantType::W4_AFP8: return 4; + default: TLLM_CHECK_WITH_INFO(false, "Invalid quant_type"); return -1; + } +} + +// Shapes here can be 2 or 3D. 2-D shapes are [num_rows, num_cols] +// 3-D shapes are [num_experts, num_rows, num_cols] +void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type, const int64_t arch_version); + +void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, + std::vector const& shape, QuantType quant_type); + +void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type); + +void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight, + std::vector const& shape, QuantType quant_type, bool force_interleave = false); + +template +void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr, + std::vector const& shape, QuantType quant_type, bool force_interleave); + +// This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight +// to implement a simple reference implementation. +template +void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, + ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector const& shape, QuantType quant_type, + bool force_interleave); + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h new file mode 100644 index 00000000000..0ec8ab2e39b --- /dev/null +++ b/sgl-kernel/3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include + +#include "cutlass/bfloat16.h" +#include "cutlass/float8.h" +#include "cutlass/half.h" + +namespace tensorrt_llm +{ +namespace kernels +{ +namespace cutlass_kernels +{ + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Tllm to Cutlass + +template +struct TllmToCutlassTypeAdapter +{ + using type = T; +}; + +template <> +struct TllmToCutlassTypeAdapter +{ + using type = cutlass::half_t; +}; + +#if defined(ENABLE_BF16) +template <> +struct TllmToCutlassTypeAdapter<__nv_bfloat16> +{ + using type = cutlass::bfloat16_t; +}; +#endif + +#if defined(ENABLE_FP8) +template <> +struct TllmToCutlassTypeAdapter<__nv_fp8_e4m3> +{ + using type = cutlass::float_e4m3_t; +}; + +template <> +struct TllmToCutlassTypeAdapter<__nv_fp8_e5m2> +{ + using type = cutlass::float_e5m2_t; +}; +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Cutlass to Tllm + +template +struct CutlassToTllmTypeAdapter +{ + using type = T; +}; + +template <> +struct CutlassToTllmTypeAdapter +{ + using type = half; +}; + +#if defined(ENABLE_BF16) +template <> +struct CutlassToTllmTypeAdapter +{ + using type = __nv_bfloat16; +}; +#endif + +#if defined(ENABLE_FP8) +template <> +struct CutlassToTllmTypeAdapter +{ + using type = __nv_fp8_e4m3; +}; + +template <> +struct CutlassToTllmTypeAdapter +{ + using type = __nv_fp8_e5m2; +}; +#endif + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass_kernels +} // namespace kernels +} // namespace tensorrt_llm diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index e5d097b14e7..1da704c099c 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -101,6 +101,17 @@ def _get_version(): "3rdparty/tensorrt_llm/common/stringUtils.cpp", "3rdparty/tensorrt_llm/common/tllmException.cpp", "3rdparty/tensorrt_llm/common/cudaFp8Utils.cu", + "3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp", + "3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp", + "3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu", + "3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint4.cu", + "3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_uint8.cu", + "3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp16.cu", + "3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint4.cu", + "3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_uint8.cu", + "3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp32_fp32.cu", + "3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp8_fp8.cu", + "3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_hopper_input.cu", ] enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1" From fc24280e779483590ee3be15461a1e190d7f1bdf Mon Sep 17 00:00:00 2001 From: zhyncs Date: Thu, 30 Jan 2025 13:06:47 -0800 Subject: [PATCH 3/4] upd --- sgl-kernel/setup.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 1da704c099c..354740ffb43 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -101,6 +101,9 @@ def _get_version(): "3rdparty/tensorrt_llm/common/stringUtils.cpp", "3rdparty/tensorrt_llm/common/tllmException.cpp", 
"3rdparty/tensorrt_llm/common/cudaFp8Utils.cu", +] + +sources_cutlass_moe_gemm = [ "3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp", "3rdparty/tensorrt_llm/kernels/cutlass_kernels/cutlass_preprocessors.cpp", "3rdparty/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_bf16.cu", @@ -123,18 +126,30 @@ def _get_version(): if torch.cuda.is_available(): if cuda_version >= (12, 0) and sm_version >= 90: nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") + nvcc_flags.append("-DCOMPILE_HOPPER_TMA_GEMMS") + nvcc_flags.append("-DCUTLASS_ARCH_MMA_SM90_SUPPORTED") + sources.extend(sources_cutlass_moe_gemm) if sm_version >= 90: nvcc_flags.extend(nvcc_flags_fp8) + nvcc_flags.append("-DENABLE_FP8") if sm_version >= 80: nvcc_flags.append("-DFLASHINFER_ENABLE_BF16") + nvcc_flags.append("-DCUTLASS_ARCH_MMA_SM80_SUPPORTED") + nvcc_flags.append("-DENABLE_BF16") else: # compilation environment without GPU if enable_sm90a: nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") + nvcc_flags.append("-DCOMPILE_HOPPER_TMA_GEMMS") + nvcc_flags.append("-DCUTLASS_ARCH_MMA_SM90_SUPPORTED") + sources.extend(sources_cutlass_moe_gemm) if enable_fp8: nvcc_flags.extend(nvcc_flags_fp8) + nvcc_flags.append("-DENABLE_FP8") if enable_bf16: nvcc_flags.append("-DFLASHINFER_ENABLE_BF16") + nvcc_flags.append("-DCUTLASS_ARCH_MMA_SM80_SUPPORTED") + nvcc_flags.append("-DENABLE_BF16") for flag in [ "-D__CUDA_NO_HALF_OPERATORS__", From d04f52512f198e3bb10d1fb4c77f90226e1a4916 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Thu, 30 Jan 2025 13:35:46 -0800 Subject: [PATCH 4/4] upd --- sgl-kernel/setup.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 354740ffb43..6f1ea43e54d 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -128,6 +128,7 @@ def _get_version(): nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") nvcc_flags.append("-DCOMPILE_HOPPER_TMA_GEMMS") nvcc_flags.append("-DCUTLASS_ARCH_MMA_SM90_SUPPORTED") + nvcc_flags.append("-DCUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED") sources.extend(sources_cutlass_moe_gemm) if sm_version >= 90: nvcc_flags.extend(nvcc_flags_fp8) @@ -142,6 +143,7 @@ def _get_version(): nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") nvcc_flags.append("-DCOMPILE_HOPPER_TMA_GEMMS") nvcc_flags.append("-DCUTLASS_ARCH_MMA_SM90_SUPPORTED") + nvcc_flags.append("-DCUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED") sources.extend(sources_cutlass_moe_gemm) if enable_fp8: nvcc_flags.extend(nvcc_flags_fp8)