From 812e7f4fb6a6e65ed8256bd16884e137a3abc55d Mon Sep 17 00:00:00 2001 From: Keith Wyss Date: Wed, 12 Feb 2025 13:50:52 -0800 Subject: [PATCH 1/2] Blockwise float8 quantizer and quantized tensor class. The classes are configurable for 128x128 blocksize and 1x128 blocksize via setting block_scaling_dim == 2,1 respectively. Scale tensors are stored in a format emenable for matrix multiplication, however the integration of matmul is deferred as a separate story. Fusions of quantization and DBIAS or activation functions are not yet implemented, and the dequantization is currently implemented in torch. Tests for quantization are included in C++ and pytorch layers, with exact comparison to reference quantizer behavior as well as an attempt to hit interesting branches through the API such as tensor creation in pytorch and CPP and dequantization of row and columnwise usage. Two CUDA kernels for quantization are included, and are direct ports of equivalents in the kitchen repository, where a subchannel recipe has been used for end to end training. --- tests/cpp/operator/CMakeLists.txt | 1 + .../cpp/operator/test_cast_float8blockwise.cu | 640 ++++++++++++++++++ tests/cpp/test_common.cu | 160 +++-- tests/cpp/test_common.h | 35 +- .../blockwise_quantizer_reference.py | 373 ++++++++++ .../test_float8_blockwise_scaling_exact.py | 288 ++++++++ tests/pytorch/test_float8blockwisetensor.py | 207 ++++++ transformer_engine/common/CMakeLists.txt | 2 + transformer_engine/common/common.h | 49 +- .../common/gemm/cublaslt_gemm.cu | 8 + .../common/include/transformer_engine/cast.h | 14 +- .../transformer_engine/transformer_engine.h | 80 +++ .../common/transformer_engine.cpp | 45 ++ .../common/transpose/cast_transpose.h | 22 +- .../common/transpose/compute_scale.cuh | 134 ++++ .../quantize_transpose_square_blockwise.cu | 640 ++++++++++++++++++ .../quantize_transpose_vector_blockwise.cu | 563 +++++++++++++++ .../common/util/cast_kernels.cuh | 29 + .../common/util/dequantize_kernels.cuh | 1 + transformer_engine/pytorch/constants.py | 16 + transformer_engine/pytorch/csrc/common.h | 30 + .../pytorch/csrc/extensions/cast.cpp | 4 +- .../pytorch/csrc/extensions/pybind.cpp | 26 + .../pytorch/csrc/extensions/quantizer.cpp | 127 ++++ .../csrc/extensions/type_converters.cpp | 32 + transformer_engine/pytorch/csrc/pybind.h | 19 +- .../_internal/float8_blockwise_tensor_base.py | 250 +++++++ .../tensor/_internal/mxfp8_tensor_base.py | 2 +- .../pytorch/tensor/float8_blockwise_tensor.py | 553 +++++++++++++++ 29 files changed, 4274 insertions(+), 76 deletions(-) create mode 100644 tests/cpp/operator/test_cast_float8blockwise.cu create mode 100644 tests/pytorch/references/blockwise_quantizer_reference.py create mode 100644 tests/pytorch/test_float8_blockwise_scaling_exact.py create mode 100644 tests/pytorch/test_float8blockwisetensor.py create mode 100644 transformer_engine/common/transpose/compute_scale.cuh create mode 100644 transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu create mode 100644 transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu create mode 100644 transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py create mode 100644 transformer_engine/pytorch/tensor/float8_blockwise_tensor.py diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index ce78fcaae2..ff225cccba 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -10,6 +10,7 @@ add_executable(test_operator test_cast_mxfp8_gated_swiglu.cu test_qdq.cu test_cast_mxfp8.cu + test_cast_float8blockwise.cu test_dequantize_mxfp8.cu test_transpose.cu test_cast_transpose.cu diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu new file mode 100644 index 0000000000..171d22be71 --- /dev/null +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -0,0 +1,640 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +constexpr size_t kBlockLen = 128; + +enum ProcessingMethod { + CAST_ONLY, + // CAST_DBIAS, + // CAST_DBIAS_DACT, + // CAST_DACT, + // CAST_ACT +}; + +enum ActivationType { + Identity, + // GeLU, + // SiLU, + // ReLU, + // QGeLU, + // SReLU +}; + +template +void scales_from_amax(float amax, const QuantizationOptions& opts, float* qscale_out, + float* qscale_inv_out) { + float input_type_max_val = Quantized_Limits::max(); + float quant_type_max_val = Quantized_Limits::max(); + float eps = opts.amax_epsilon; + amax = std::max(amax, eps); + float qscale = quant_type_max_val / amax; + if (std::isinf(qscale)) { + qscale = input_type_max_val; + } + if (std::isnan(qscale) || amax == 0) { + qscale = 1.0; + } + + if (opts.force_pow_2_scales && qscale != 0.0) { + uint32_t scale_bits = *reinterpret_cast(&qscale); + // Scale must be positive, shift it + uint8_t exp = scale_bits >> 23; + ASSERT_FALSE(exp == 0) << "Subnormals in this path is a logic error."; + qscale = ldexpf(1.0f, static_cast(exp) - 127); + } + + float qscale_inv = 1.0 / qscale; + *qscale_out = qscale; + *qscale_inv_out = qscale_inv; +} + +template +void ref_quantize(const ProcessingMethod processing_method, const InputType* input, + const std::pair& input_hw, OutputType* output, float* scale_inv, + OutputType* output_t, float* scale_inv_t, const QuantizationOptions& opts) { + constexpr size_t kBlockLenX = kBlockLen; + constexpr size_t kBlockLenY = kBlockLen; + + auto quantize_element = [](InputType element, float qscale) -> OutputType { + // Scale in FP32 and cast result to nearest FP8. + return static_cast(float(element) * qscale); + }; + + size_t height = input_hw.first; + size_t width = input_hw.second; + size_t blocks_x = (width + kBlockLenX - 1) / kBlockLenX; + size_t blocks_y = (height + kBlockLenY - 1) / kBlockLenY; + // Find the absolute maximum value in the block + for (size_t block_x = 0; block_x < blocks_x; ++block_x) { + for (size_t block_y = 0; block_y < blocks_y; ++block_y) { + float amax = 0.0f; + // Calculate amax for a tile. + for (size_t i = 0; i < kBlockLenX; ++i) { + for (size_t j = 0; j < kBlockLenY; ++j) { + size_t x_pos = i + block_x * kBlockLenX; + size_t y_pos = j + block_y * kBlockLenY; + if (y_pos >= height || x_pos >= width) { + continue; + } + float val = static_cast(input[y_pos * width + x_pos]); + amax = std::max(amax, std::abs(val)); + } + } + + // We've calculated amax for a tile. Calculate scale and + // scale_inv and populate outputs. + float qscale, qscale_inv; + scales_from_amax(amax, opts, &qscale, &qscale_inv); + + // NOTE: This reference function outputs contigous scale tensors. + // It calculates a naive scale data format. Strides are handled + // in comparison. + if (scale_inv != nullptr) { + scale_inv[block_y * blocks_x + block_x] = qscale_inv; + } + if (scale_inv_t != nullptr) { + scale_inv_t[block_x * blocks_y + block_y] = qscale_inv; + } + + for (size_t i = 0; i < kBlockLenX; ++i) { + for (size_t j = 0; j < kBlockLenY; ++j) { + size_t x_pos = i + block_x * kBlockLenX; + size_t y_pos = j + block_y * kBlockLenY; + if (y_pos >= height || x_pos >= width) { + continue; + } + if (output != nullptr) { + output[y_pos * width + x_pos] = quantize_element(input[y_pos * width + x_pos], qscale); + } + if (output_t != nullptr) { + output_t[x_pos * height + y_pos] = + quantize_element(input[y_pos * width + x_pos], qscale); + } + } + } + } + } +} + +template +void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method, + const InputType* input, + const std::pair& input_hw, + OutputType* output, float* scale_inv, OutputType* output_t, + float* scale_inv_t, const QuantizationOptions& opts) { + float input_type_max_val = Quantized_Limits::max(); + float quant_type_max_val = Quantized_Limits::max(); + + constexpr size_t kBlockLenX = kBlockLen; + + auto quantize_element = [](InputType element, float qscale) -> OutputType { + // Scale in FP32 and cast result to nearest FP8. + return static_cast(float(element) * qscale); + }; + + size_t height = input_hw.first; + size_t width = input_hw.second; + size_t blocks_x = (width + kBlockLenX - 1) / kBlockLenX; + size_t blocks_x_t = (height + kBlockLenX - 1) / kBlockLenX; + if (output != nullptr && scale_inv != nullptr) { + // Find the absolute maximum value in the block + for (size_t block_x = 0; block_x < blocks_x; ++block_x) { + for (size_t y = 0; y < height; ++y) { + float amax = 0.0f; + // Calculate amax for a tile. + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t x_pos = i + block_x * kBlockLenX; + if (x_pos >= width) { + continue; + } + float val = static_cast(input[y * width + x_pos]); + amax = std::max(amax, std::abs(val)); + } + + // We've calculated amax for a tile. Calculate scale and + // scale_inv and populate outputs. + float qscale, qscale_inv; + scales_from_amax(amax, opts, &qscale, &qscale_inv); + + scale_inv[y + height * block_x] = qscale_inv; + + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t x_pos = i + block_x * kBlockLenX; + if (x_pos >= width) { + continue; + } + output[y * width + x_pos] = quantize_element(input[y * width + x_pos], qscale); + } + } + } + } + if (output_t != nullptr && scale_inv_t != nullptr) { + // Find the absolute maximum value in the block + for (size_t block_x_t = 0; block_x_t < blocks_x_t; ++block_x_t) { + for (size_t x = 0; x < width; ++x) { + float amax = 0.0f; + // Calculate amax for a tile. + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t y_pos = i + block_x_t * kBlockLenX; + if (y_pos >= height) { + continue; + } + float val = static_cast(input[x + y_pos * width]); + amax = std::max(amax, std::abs(val)); + } + + // We've calculated amax for a tile. Calculate scale and + // scale_inv and populate outputs. + float qscale, qscale_inv; + scales_from_amax(amax, opts, &qscale, &qscale_inv); + + scale_inv_t[x + width * block_x_t] = qscale_inv; + + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t y_pos = i + block_x_t * kBlockLenX; + if (y_pos >= height) { + continue; + } + output_t[x * height + y_pos] = quantize_element(input[y_pos * width + x], qscale); + } + } + } + } +} + +void compare_scaling_factors(const std::string& name, const float* test, const float* ref, + const size_t row_blocks, const size_t col_blocks, + const size_t test_stride, const size_t ref_stride) { + for (int i = 0; i < row_blocks; ++i) { + for (int j = 0; j < col_blocks; ++j) { + const int test_idx = i * test_stride + j; + const int ref_idx = i * ref_stride + j; + ASSERT_FALSE(test[test_idx] != ref[ref_idx]) + << "Error in " << name << std::endl + << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx + << "," << ref_idx; + } + } +} + +void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test, + const float* ref, const size_t rows, + const size_t col_blocks) { + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < col_blocks; ++j) { + const int test_idx = i + rows * j; + const int ref_idx = i + rows * j; + ASSERT_FALSE(test[test_idx] != ref[ref_idx]) + << "Error in " << name << std::endl + << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx + << "," << ref_idx; + } + } +} + +template +void runTestCase(const ProcessingMethod processing_method, const std::vector& shape, + const bool rowwise, const bool colwise, InputsFillCase fill_case, + const QuantizationOptions& opts) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + size_t blocks_x = (cols + kBlockLen - 1) / kBlockLen; + size_t blocks_y = (rows + kBlockLen - 1) / kBlockLen; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_BLOCK_SCALING, &opts); + Tensor output_dbias("output_dbias", {cols}, itype); + + std::unique_ptr ref_output = std::make_unique(rows * cols); + std::unique_ptr ref_output_t = std::make_unique(rows * cols); + std::unique_ptr ref_scale_inv = std::make_unique(blocks_y * blocks_x); + std::unique_ptr ref_scale_inv_t = std::make_unique(blocks_y * blocks_x); + + if (!rowwise) { + ref_output = nullptr; + ref_scale_inv = nullptr; + } + if (!colwise) { + ref_output_t = nullptr; + ref_scale_inv_t = nullptr; + } + + fillCase(&input, fill_case); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output_c.data(), 0); + break; + } + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + ref_quantize(processing_method, input.rowwise_cpu_dptr(), + {rows, cols}, ref_output.get(), ref_scale_inv.get(), + ref_output_t.get(), ref_scale_inv_t.get(), opts); + + float atol = 0.0; + float rtol = 0.0; + + auto scale_align_stride = [](size_t inner_elements) -> size_t { + return ((inner_elements + 4u - 1u) / 4u) * 4u; + }; + + if (rowwise) { + compareResults("output_c", output_c, ref_output.get(), true, atol, rtol); + compare_scaling_factors("scale_inv", output_c.rowwise_cpu_scale_inv_ptr(), + ref_scale_inv.get(), blocks_y, blocks_x, scale_align_stride(blocks_x), + blocks_x); + } + if (colwise) { + compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol); + compare_scaling_factors("scale_inv_t", output_c.columnwise_cpu_scale_inv_ptr(), + ref_scale_inv_t.get(), blocks_x, blocks_y, scale_align_stride(blocks_y), + blocks_y); + } +} + +template +void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, + const std::vector& shape, const bool rowwise, + const bool colwise, InputsFillCase fill_case, + const QuantizationOptions& opts) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + size_t blocks_x = (cols + kBlockLen - 1) / kBlockLen; + size_t blocks_x_t = (rows + kBlockLen - 1) / kBlockLen; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_BLOCK_SCALING, &opts); + Tensor output_dbias("output_dbias", {cols}, itype); + + std::unique_ptr ref_output = std::make_unique(rows * cols); + std::unique_ptr ref_output_t = std::make_unique(rows * cols); + std::unique_ptr ref_scale_inv = std::make_unique(rows * blocks_x); + std::unique_ptr ref_scale_inv_t = std::make_unique(cols * blocks_x_t); + + if (!rowwise) { + ref_output = nullptr; + ref_scale_inv = nullptr; + } + if (!colwise) { + ref_output_t = nullptr; + ref_scale_inv_t = nullptr; + } + + fillCase(&input, fill_case); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output_c.data(), 0); + break; + } + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + ref_quantize_onedimensional_blocks( + processing_method, input.rowwise_cpu_dptr(), {rows, cols}, ref_output.get(), + ref_scale_inv.get(), ref_output_t.get(), ref_scale_inv_t.get(), opts); + + float atol = 0.0; + float rtol = 0.0; + + if (rowwise) { + compareResults("output_c", output_c, ref_output.get(), true, atol, rtol); + compare_scaling_factors_one_dimensional_blocks("scale_inv", + output_c.rowwise_cpu_scale_inv_ptr(), + ref_scale_inv.get(), rows, blocks_x); + } + if (colwise) { + compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol); + compare_scaling_factors_one_dimensional_blocks("scale_inv_t", + output_c.columnwise_cpu_scale_inv_ptr(), + ref_scale_inv_t.get(), cols, blocks_x_t); + } +} + +std::vector> matrix_sizes = { + {1, 16}, {16, 48}, {65, 96}, {128, 128}, {256, 256}, {993, 512}, + {256, 65536}, {2048, 6144}, {16384, 128}, {32768, 160}, {4096, 1632}, {1024, 1}, + {32, 1024}, {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, +}; + +std::vector input_scenarios = { + InputsFillCase::uniform, +}; + +std::vector processing_methods = { + ProcessingMethod::CAST_ONLY, + // ProcessingMethod::CAST_DBIAS, + // ProcessingMethod::CAST_DBIAS_DACT, + // ProcessingMethod::CAST_DACT, + // ProcessingMethod::CAST_ACT, +}; + +// Only GeLU activation tests are supported +std::vector Activation_types = { + ActivationType::Identity, + // ActivationType::GeLU, + // ActivationType::SiLU, + // ActivationType::ReLU, + // ActivationType::QGeLU, + // ActivationType::SReLU, +}; + + +std::vector amax_epsilons = { + 0.0f, + // Set large epsilon to get observable behavior. + 0.1f, +}; + +} // namespace + +class FusedCastFloat8BlockwiseTestSuite + : public ::testing::TestWithParam, transformer_engine::DType, + transformer_engine::DType, InputsFillCase, bool, float, bool>> {}; + +class FusedCastFloat8VectorwiseTestSuite + : public ::testing::TestWithParam, transformer_engine::DType, + transformer_engine::DType, InputsFillCase, bool, float, bool>> {}; + +#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ + switch (OP_FUNC_TYPE) { \ + case ActivationType::Identity: { \ + constexpr auto OP = &identity; \ + { \ + __VA_ARGS__ \ + } \ + } break; \ + } + +#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ + switch (OP_FUNC_TYPE) { \ + case ActivationType::Identity: { \ + constexpr auto OP = &identity; \ + { \ + __VA_ARGS__ \ + } \ + } break; \ + } + +TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) { + if (getDeviceComputeCapability() < hopperComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationType Act_type = std::get<1>(GetParam()); + const auto matrix_size = std::get<2>(GetParam()); + const DType input_type = std::get<3>(GetParam()); + const DType output_type = std::get<4>(GetParam()); + const InputsFillCase fill_case = std::get<5>(GetParam()); + const bool colwise = std::get<6>(GetParam()); + const bool rowwise = true; + const float eps = std::get<7>(GetParam()); + const bool force_pow_2 = std::get<8>(GetParam()); + + QuantizationOptions q_opts; + q_opts.force_pow_2_scales = force_pow_2; + q_opts.amax_epsilon = eps; + q_opts.block_scaling_dim = 2u; + + if (colwise && matrix_size.size() < 2) { + // test_common Tensor initialization code does not + // handle this case. + GTEST_SKIP(); + } + // Skips non Act tests if the Activation type is not an identity + if ( // (processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + (processing_method == ProcessingMethod::CAST_ONLY) && Act_type != ActivationType::Identity) { + GTEST_SKIP(); + } + // Skips Act tests if the Activation is an identity + // if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + // || processing_method == ProcessingMethod::CAST_DACT + // || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) { + // GTEST_SKIP(); + // } + + DACT_FUNC_SWITCH( + Act_type, OP, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( + output_type, OutputType, + runTestCase(processing_method, matrix_size, rowwise, colwise, + fill_case, q_opts);););); +} + +TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) { + if (getDeviceComputeCapability() < hopperComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationType Act_type = std::get<1>(GetParam()); + const auto matrix_size = std::get<2>(GetParam()); + const DType input_type = std::get<3>(GetParam()); + const DType output_type = std::get<4>(GetParam()); + const InputsFillCase fill_case = std::get<5>(GetParam()); + const bool colwise = std::get<6>(GetParam()); + const bool rowwise = true; + const float eps = std::get<7>(GetParam()); + const bool force_pow_2 = std::get<8>(GetParam()); + + QuantizationOptions q_opts; + q_opts.force_pow_2_scales = force_pow_2; + q_opts.amax_epsilon = eps; + q_opts.block_scaling_dim = 1u; + + if (colwise && matrix_size.size() < 2) { + // test_common Tensor initialization code does not + // handle this case. + GTEST_SKIP(); + } + // Skips non Act tests if the Activation type is not an identity + if ( // (processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + (processing_method == ProcessingMethod::CAST_ONLY) && Act_type != ActivationType::Identity) { + GTEST_SKIP(); + } + // Skips Act tests if the Activation is an identity + // if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + // || processing_method == ProcessingMethod::CAST_DACT + // || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) { + // GTEST_SKIP(); + // } + + DACT_FUNC_SWITCH( + Act_type, OP, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( + output_type, OutputType, + runTestCaseOneDimensionalBlocks( + processing_method, matrix_size, rowwise, colwise, fill_case, q_opts);););); +} + +std::string to_string(const ProcessingMethod method) { + switch (method) { + case ProcessingMethod::CAST_ONLY: + return "CAST_ONLY"; + // case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS"; + // case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT"; + // case ProcessingMethod::CAST_DACT: return "CAST_DACT"; + // case ProcessingMethod::CAST_ACT: return "CAST_ACT"; + default: + return ""; + } +} + +std::string to_string(const ActivationType Act_type) { + switch (Act_type) { + case ActivationType::Identity: + return "Identity"; + // case ActivationType::GeLU: return "GeLU"; + // case ActivationType::SiLU: return "SiLU"; + // case ActivationType::ReLU: return "ReLU"; + // case ActivationType::QGeLU: return "QGeLU"; + // case ActivationType::SReLU: return "SReLU"; + default: + return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, FusedCastFloat8BlockwiseTestSuite, + ::testing::Combine(::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(matrix_sizes), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(false, true)), + [](const testing::TestParamInfo& info) { + std::string name = + to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for (const auto& s : shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<3>(info.param)) + "X" + + test::typeName(std::get<4>(info.param)) + "X" + + test::caseName(std::get<5>(info.param)) + "X" + + std::to_string(std::get<6>(info.param)) + "X" + + std::to_string(std::get<7>(info.param) != 0.0f) + "X" + + std::to_string(std::get<8>(info.param)); + return name; + }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, FusedCastFloat8VectorwiseTestSuite, + ::testing::Combine(::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(matrix_sizes), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(false, true)), + [](const testing::TestParamInfo& info) { + std::string name = + to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for (const auto& s : shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<3>(info.param)) + "X" + + test::typeName(std::get<4>(info.param)) + "X" + + test::caseName(std::get<5>(info.param)) + "X" + + std::to_string(std::get<6>(info.param)) + "X" + + std::to_string(std::get<7>(info.param) != 0.0f) + "X" + + std::to_string(std::get<8>(info.param)); + return name; + }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index ec4a9bdbb7..78bbb97687 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -118,7 +119,8 @@ NVTEShape convertShape(const std::vector& shape) { } std::pair get_scales(const NVTEShape& shape, - const NVTEScalingMode scaling_mode) { + const NVTEScalingMode scaling_mode, + const int block_scaling_dim) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { scale_inv_meta ret; ret.shape = {1}; @@ -136,27 +138,19 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; - auto block_alignment = std::vector{128ul,4ul}; + auto block_alignment = std::vector{128ul, 4ul}; { auto alignment = block_alignment[0]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, - static_cast(1)), - alignment) * alignment; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(1)), alignment) * alignment; alignment = block_alignment[1]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, - static_cast(32)), - alignment) * alignment; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(32)), alignment) * alignment; ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { auto alignment = block_alignment[1]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, - static_cast(32)), - alignment) * alignment; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(32)), alignment) * alignment; alignment = block_alignment[0]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, - static_cast(1)), - alignment) * alignment; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(1)), alignment) * alignment; ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat8E8M0; @@ -166,6 +160,61 @@ std::pair get_scales(const NVTEShape& shape, return {ret_rowwise, ret_colwise}; } + if (scaling_mode == NVTE_BLOCK_SCALING) { + if (block_scaling_dim == 2) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + scale_inv_meta ret_rowwise, ret_colwise; + + { + auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(128)), 4) * 4; + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + } + { + auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast(128)), 4) * 4; + ret_colwise.shape = {scale_dim_0, scale_dim_1}; + } + ret_rowwise.type = DType::kFloat32; + ret_colwise.type = DType::kFloat32; + ret_rowwise.type_size = sizeof(float); + ret_colwise.type_size = sizeof(float); + + return {ret_rowwise, ret_colwise}; + } else if (block_scaling_dim == 1) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + scale_inv_meta ret_rowwise, ret_colwise; + + { + auto scale_dim_1 = first_dim; + auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + } + { + auto scale_dim_1 = last_dim; + auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); + ret_colwise.shape = {scale_dim_0, scale_dim_1}; + } + ret_rowwise.type = DType::kFloat32; + ret_colwise.type = DType::kFloat32; + ret_rowwise.type_size = sizeof(float); + ret_colwise.type_size = sizeof(float); + return {ret_rowwise, ret_colwise}; + } else { + NVTE_ERROR("Unsupported block scaling dim!"); + } + } NVTE_ERROR("Invalid scaling mode!"); } @@ -173,7 +222,8 @@ std::pair get_scales(const NVTEShape& shape, Tensor::Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise, const bool columnwise, - const NVTEScalingMode &scaling_mode) { + const NVTEScalingMode &scaling_mode, + const QuantizationOptions* q_opts) { name_ = name; const size_t seed = create_seed_from_tensor_name(name); gen_.seed(seed); @@ -197,9 +247,12 @@ Tensor::Tensor(const std::string& name, std::vector normalized_shape_v = {product(shape, 0, shape.ndim - 1), shape.data[shape.ndim - 1]}; NVTEShape normalized_shape = convertShape(normalized_shape_v); - + size_t block_scaling_dim = 0; + if (q_opts != nullptr) { + block_scaling_dim = q_opts->block_scaling_dim; + } std::vector columnwise_shape_vec; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING) { // Transpose when tensor scaling columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); for (size_t i = 0; i < shape.ndim - 1; ++i) { @@ -256,27 +309,34 @@ Tensor::Tensor(const std::string& name, std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); } } else { - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, - tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(normalized_shape, tensor_.scaling_mode(), block_scaling_dim); auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; auto scale_shape = rowwise_scale_meta.shape; auto columnwise_scale_shape = colwise_scale_meta.shape; if (rowwise) { - cudaMalloc((void**)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) + cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size); rowwise_scale_inv_cpu_data_ = std::make_unique(rowwise_scale_size); std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0); - tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat8E8M0, scale_shape); + auto scale_dtype = rowwise_scale_meta.type; + tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape); } if (columnwise) { cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*) cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size); columnwise_scale_inv_cpu_data_ = std::make_unique(columnwise_scale_size); std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0); - tensor_.set_columnwise_scale_inv(columnwise_scale_inv, DType::kFloat8E8M0, columnwise_scale_shape); + auto scale_dtype = colwise_scale_meta.type; + tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape); } } + if (q_opts != nullptr) { + tensor_.set_qopt_force_pow_2_scales(q_opts->force_pow_2_scales); + tensor_.set_qopt_amax_epsilon(q_opts->amax_epsilon); + tensor_.set_qopt_block_scaling_dim(q_opts->block_scaling_dim); + } } } @@ -306,7 +366,8 @@ void Tensor::to_cpu() const { sizeof(float), cudaMemcpyDeviceToHost); } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(s, tensor_.scaling_mode(), tensor_.get_qopt_block_scaling_dim()); if (rowwise_) { auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), @@ -342,7 +403,8 @@ void Tensor::from_cpu() const { cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(s, tensor_.scaling_mode(), tensor_.get_qopt_block_scaling_dim()); if (rowwise_) { auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, @@ -361,7 +423,7 @@ void Tensor::from_cpu() const { void Tensor::set_scale(float scale) { if (isFp8Type(dtype())) { NVTE_CHECK(scale_cpu_data_); - if (is_tensor_scaling(tensor_.scaling_mode())) { + if (is_tensor_scaling(tensor_.scaling_mode())) { *scale_cpu_data_ = scale; from_cpu(); } @@ -376,27 +438,29 @@ void Tensor::set_scale_inv(float scale_inv) { if (columnwise_) { NVTE_CHECK(columnwise_scale_inv_cpu_data_); } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode()); + + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(tensor_.shape(), tensor_.scaling_mode(), tensor_.get_qopt_block_scaling_dim()); if (rowwise_) { auto num_scales = product(rowwise_scale_meta.shape); - if (num_scales == 1){ + if (num_scales == 1) { rowwise_cpu_scale_inv_ptr()[0] = scale_inv; - } else{ + } else { std::uniform_int_distribution dis(0, 127); - auto* scale_inv_ptr = rowwise_cpu_scale_inv_ptr(); - for (size_t i = 0; i < num_scales; i++){ + auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++) { scale_inv_ptr[i] = dis(gen_); } } } if (columnwise_) { auto num_scales = product(colwise_scale_meta.shape); - if (num_scales == 1){ + if (num_scales == 1) { columnwise_cpu_scale_inv_ptr()[0] = scale_inv; - } else{ + } else { std::uniform_int_distribution dis(0, 127); - auto* scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); - for (size_t i = 0; i < num_scales; i++){ + auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++) { scale_inv_ptr[i] = dis(gen_); } } @@ -406,23 +470,20 @@ void Tensor::set_scale_inv(float scale_inv) { } void Tensor::shareFP8Meta(const Tensor &other) { - if(isFp8Type(dtype()) && isFp8Type(other.dtype())) { + if (isFp8Type(dtype()) && isFp8Type(other.dtype())) { auto new_tensor = TensorWrapper(other.tensor_.scaling_mode()); auto my_rowwise_data = tensor_.get_rowwise_data(); - new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, - static_cast(my_rowwise_data.dtype), + new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast(my_rowwise_data.dtype), my_rowwise_data.shape); auto my_columnwise_data = tensor_.get_columnwise_data(); new_tensor.set_columnwise_data(my_columnwise_data.data_ptr, static_cast(my_columnwise_data.dtype), my_columnwise_data.shape); auto other_amax = other.tensor_.get_amax(); - new_tensor.set_amax(other_amax.data_ptr, - static_cast(other_amax.dtype), + new_tensor.set_amax(other_amax.data_ptr, static_cast(other_amax.dtype), other_amax.shape); auto other_scale = other.tensor_.get_scale(); - new_tensor.set_scale(other_scale.data_ptr, - static_cast(other_scale.dtype), + new_tensor.set_scale(other_scale.data_ptr, static_cast(other_scale.dtype), other_scale.shape); auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv(); new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr, @@ -453,9 +514,7 @@ std::string to_string(const std::vector &v) { std::vector unravel(const size_t i, const NVTEShape &shape) { std::vector ret; size_t current_i = i; - for (size_t current = shape.ndim - 1; - current > 0; - --current) { + for (size_t current = shape.ndim - 1; current > 0; --current) { ret.push_back(current_i % shape.data[current]); current_i /= shape.data[current]; } @@ -693,7 +752,7 @@ void fillCase_special(Tensor *t) { }); } else { double minAbs = -2.0; - double maxAbs = 1.0; + double maxAbs = 1.0; if constexpr (Case != InputsFillCase::uniform) { minAbs = Quantized_Limits::ranges[Case]; maxAbs = Quantized_Limits::ranges[Case + 1]; @@ -752,14 +811,13 @@ void setRandomScaleInv(Tensor *t) { } bool isFp8Type(DType type) { - return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; + return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; } -int32_t getDeviceComputeCapability() -{ - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); - return 10 * deviceProp.major + deviceProp.minor; +int32_t getDeviceComputeCapability() { + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + return 10 * deviceProp.major + deviceProp.minor; } size_t first_dimension(const std::vector &shape) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index dc515ccb8e..01134a23ca 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -95,21 +95,29 @@ struct TypeInfo{ constexpr static size_t size = sizeof(T); }; +struct QuantizationOptions { + bool force_pow_2_scales = false; + float amax_epsilon = 0.0; + size_t block_scaling_dim = 2u; +}; + class Tensor { public: Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, + const QuantizationOptions* q_opts = nullptr); Tensor(const std::string& name, const std::vector &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : - Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, + const QuantizationOptions* q_opts = nullptr) : + Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode, q_opts) {} Tensor() {} @@ -136,25 +144,19 @@ class Tensor { if (scale_inv != nullptr) { cudaFree(scale_inv); } - if (columnwise_data_ptr != nullptr){ + if (columnwise_data_ptr != nullptr) { cudaFree(columnwise_data_ptr); } - if (columnwise_scale_inv != nullptr){ + if (columnwise_scale_inv != nullptr) { cudaFree(columnwise_scale_inv); } } - NVTETensor data() const noexcept { - return tensor_.data(); - } + NVTETensor data() const noexcept { return tensor_.data(); } - NVTEShape rowwise_shape() const noexcept { - return tensor_.get_rowwise_data().shape; - } + NVTEShape rowwise_shape() const noexcept { return tensor_.get_rowwise_data().shape; } - NVTEShape columnwise_shape() const noexcept { - return tensor_.get_columnwise_data().shape; - } + NVTEShape columnwise_shape() const noexcept { return tensor_.get_columnwise_data().shape; } NVTEShape rowwise_scale_inv_shape() const { NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); @@ -221,6 +223,8 @@ class Tensor { T *rowwise_cpu_scale_inv_ptr(){ if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -232,6 +236,8 @@ class Tensor { T *columnwise_cpu_scale_inv_ptr(){ if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -455,6 +461,7 @@ extern std::vector all_fp_types; bool isFp8Type(DType type); int32_t getDeviceComputeCapability(); +constexpr int32_t hopperComputeCapability = 90; constexpr int32_t blackwellComputeCapability = 100; } // namespace test diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py new file mode 100644 index 0000000000..5331fd4839 --- /dev/null +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -0,0 +1,373 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import dataclasses +import math +import torch +from typing import Optional, Protocol, Tuple + + +@dataclasses.dataclass() +class QuantizeResult: + data: torch.Tensor + scale: torch.Tensor + data_t: Optional[torch.Tensor] + scale_t: Optional[torch.Tensor] + + +# FIXME(kwyss): Put this in a common location for per-tensor current +# scaling reference +def _scale_from_amax_tensor( + x_dtype: torch.dtype, + amax: torch.Tensor, + quant_dtype: torch.dtype, + *, + eps: float, + pow_2_scales: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Derives quantization and dequantization from amax and options. + + Reference implementation for scale calculation. + + Returns: + - scale: quantization scales + - scale_inv: dequantization scales + - amax: Amax tensor with updates made for extrema values. + """ + assert amax.dtype == torch.float, "amax must be a float tensor." + fp8_max = torch.finfo(quant_dtype).max + # Clamping amax to avoid division by small numbers + amax = torch.max(amax, torch.tensor(eps)) + + # Compute scale factor + scale = torch.div(fp8_max, amax) + # Note frexp doesn't give back inf for exponent with an inf input + # We take care of inf before pow_2_scales + scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale) + if pow_2_scales: + # Calculate rounded down exponent + _, exp = torch.frexp(scale) + # Positive numbers are always returned as mant, exp with + # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with + # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because + # of the shift. Subnormal and zero cases need not be considered because + # the smallest possible result of fp8_max / amax is still normal. + exp = exp - 1 + # No subnormals and zero. + assert (exp > -127).all() + unity = torch.tensor([1.0], device=exp.device) + torch.ldexp(unity, exp, out=scale) + # Case where amax is inf. The frexp, ldexp logic changes 0.0 scales + # Return 0.0 for 0.0 scale for consistency with non-pow2 scale + # calculation. + scale = torch.where(amax == float("inf"), 0.0, scale) + + # Handle overflow cases for amax zero causing NaN + scale = torch.where(amax == 0, 1.0, scale) + + # Compute scale_inv + scale_inv = torch.reciprocal(scale) + + return scale, scale_inv, amax + + +@dataclasses.dataclass() +class CuBLASScaleMunger: + + def munge_scale_shapes_for_backend( + self, + unmunged: QuantizeResult, + tile_shape: Tuple[int, int], + ) -> QuantizeResult: + """ + cuBLAS GEMMs requires 1x128 quantized tensors to be have scales transposed + so that for an (M, N) tensor, the scales are (RounUpDiv(N, 128), M) + + For 128x128 quantized tensors, the GEMM expects (M, PadToAlign(RoundUpDivide(N, 128), 4)) + format. If RoundUpDivide(N, 128) is not divisible by 4, a transformation is required + """ + if tile_shape[0] != 1: + # 2D block quantized tensor needs padding for cuBLAS GEMM. + def _munge_scale_tensor(s: torch.Tensor) -> torch.Tensor: + M, K = s.shape + if K % 4 == 0: + return s + k_pad = 4 - (K % 4) + return torch.nn.functional.pad( + s, (0, k_pad), mode="constant", value=0 + ).contiguous() + + s = _munge_scale_tensor(unmunged.scale) + if unmunged.scale_t is None: + s_t = None + else: + s_t = _munge_scale_tensor(unmunged.scale_t) + return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t) + + # 1D block quantized tensors needs transpose to prepare for the GEMM. + s = unmunged.scale.transpose(-1, -2).contiguous() + if unmunged.scale_t is None: + s_t = None + else: + s_t = unmunged.scale_t.transpose(-1, -2).contiguous() + return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t) + + def demunge_scale_shape_from_backend( + cls, + qtensor_shape: Tuple[int, int], + scales: torch.Tensor, + tile_shape: Tuple[int, int], + ) -> torch.Tensor: + """ + Inverse operation of munge_scale_shapes_for_backend + """ + if tile_shape[0] != 1: + # 2D block quantized tensor may need padding stripped off + derived_scale_k_shape = math.ceil(qtensor_shape[1] / tile_shape[1]) + M, K = scales.shape + if derived_scale_k_shape == K: + return scales + else: + return scales[:, :derived_scale_k_shape].contiguous() + return scales.transpose(-1, -2).contiguous() + + +@dataclasses.dataclass() +class BlockwiseQuantizerReference: + """ + A reference QuantizeOp for subchannel/block hybrid quantization. + + Defers to ref GEMMs and quantizization formatting based on the backend. + """ + + def __init__(self) -> None: + self.scale_munger = CuBLASScaleMunger() + + @classmethod + def _quantize_square_block_tiling( + cls, + x: torch.Tensor, + quant_dtype: torch.dtype, + tile_len: int, + *, + return_transpose: bool, + pow_2_scales: bool, + eps: float, + ) -> QuantizeResult: + M, K = x.shape + + pad_m_k = [0, 0] + if K % tile_len != 0: + pad_m_k[1] = tile_len - (K % tile_len) + if M % tile_len != 0: + pad_m_k[0] = tile_len - (M % tile_len) + + unpadded_m, unpadded_k = M, K + if pad_m_k[0] != 0 or pad_m_k[1] != 0: + x = torch.nn.functional.pad( + x, (0, pad_m_k[1], 0, pad_m_k[0]), mode="constant", value=0 + ).contiguous() + M, K = x.shape + + x_tiled = x.reshape(M // tile_len, tile_len, K // tile_len, tile_len) + amax_grid = ( + torch.abs(x_tiled.transpose(-3, -2)) + .reshape(M // tile_len, K // tile_len, tile_len**2) + .amax(dim=-1) + ).float() + dtype_max = torch.finfo(quant_dtype).max + + scale, scale_inv, _ = _scale_from_amax_tensor( + x_dtype=x.dtype, + amax=amax_grid, + quant_dtype=quant_dtype, + pow_2_scales=pow_2_scales, + eps=eps, + ) + qx = x_tiled * scale.reshape(M // tile_len, 1, K // tile_len, 1) + qx = torch.clamp(qx, min=-dtype_max, max=dtype_max) + qx = qx.to(dtype=quant_dtype) + qx = qx.reshape(M, K) + if unpadded_k != K or unpadded_m != M: + qx = qx[:unpadded_m, :unpadded_k].contiguous() + if return_transpose: + # Valid because of square block sizes + qx_t = qx.transpose(-1, -2).contiguous() + scale_inv_t = scale_inv.transpose(-1, -2).contiguous() + else: + qx_t = None + scale_inv_t = None + + return QuantizeResult( + data=qx, scale=scale_inv, data_t=qx_t, scale_t=scale_inv_t + ) + + @classmethod + def _quantize_vectorwise_reference( + cls, + x: torch.Tensor, + quant_dtype: torch.dtype, + tile_len: int, + *, + pow_2_scales: bool, + eps: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + M, K = x.shape + dtype_max = torch.finfo(quant_dtype).max + x_tiled = x.reshape(M, K // tile_len, tile_len) + amax_grid = torch.abs(x_tiled).amax(dim=-1).float() + scale, scale_inv, _ = _scale_from_amax_tensor( + x_dtype=x.dtype, + amax=amax_grid, + quant_dtype=quant_dtype, + pow_2_scales=pow_2_scales, + eps=eps, + ) + qx = x_tiled * scale.reshape(M, K // tile_len, 1) + qx = torch.clamp(qx, min=-dtype_max, max=dtype_max) + qx = qx.to(dtype=quant_dtype) + qx = qx.reshape(M, K) + return qx, scale_inv + + @classmethod + def _quantize_vector_tiling( + cls, + x: torch.Tensor, + quant_dtype: torch.dtype, + tile_len: int, + *, + return_transpose: bool, + pow_2_scales: bool, + eps: float, + ) -> QuantizeResult: + M, K = x.shape + + if K % tile_len == 0: + qref_input = x + else: + pad_amount = tile_len - (K % tile_len) + pad = (0, pad_amount) + qref_input = torch.nn.functional.pad(x, pad, mode="constant", value=0) + qout_padded, scale_inv = cls._quantize_vectorwise_reference( + qref_input, + quant_dtype, + tile_len=tile_len, + pow_2_scales=pow_2_scales, + eps=eps, + ) + if K % tile_len == 0: + qout = qout_padded + else: + qout = qout_padded[:, :K].contiguous() + + if return_transpose: + if M % tile_len == 0: + qref_input = x.transpose(-1, -2).contiguous() + else: + amount_to_pad = tile_len - (M % tile_len) + pad = (0, amount_to_pad) + qref_input = torch.nn.functional.pad( + x.transpose(-1, -2), pad, mode="constant", value=0 + ).contiguous() + qout_t_padded, scale_inv_t = cls._quantize_vectorwise_reference( + qref_input, + quant_dtype, + tile_len=tile_len, + pow_2_scales=pow_2_scales, + eps=eps, + ) + if M % tile_len == 0: + qout_t = qout_t_padded + else: + qout_t = qout_t_padded[:, :M].contiguous() + else: + qout_t, scale_inv_t = None, None + + return QuantizeResult( + data=qout, scale=scale_inv, data_t=qout_t, scale_t=scale_inv_t + ) + + def ref_dequantize_rowwise( + self, + q: torch.Tensor, + quant_tile_shape: Tuple[int, int], + s: torch.Tensor, + dtype: torch.dtype, + ) -> torch.Tensor: + assert q.dim() == 2 + q_M, q_K = q.shape + s = self.scale_munger.demunge_scale_shape_from_backend( + (q_M, q_K), s, quant_tile_shape + ) + assert len(s.shape) == 2 + m_tiles, k_tiles = s.shape + M, K = q.shape + unpadded_m, unpadded_k = M, K + if M % quant_tile_shape[0] != 0 or K % quant_tile_shape[1] != 0: + m_pad_amount = ( + quant_tile_shape[0] - (M % quant_tile_shape[0]) + ) % quant_tile_shape[0] + k_pad_amount = ( + quant_tile_shape[1] - (K % quant_tile_shape[1]) + ) % quant_tile_shape[1] + q = torch.nn.functional.pad( + q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0 + ).contiguous() + M, K = q.shape + q_tiled = q.reshape(m_tiles, quant_tile_shape[0], k_tiles, quant_tile_shape[1]) + result = q_tiled.to(dtype) * s.reshape(m_tiles, 1, k_tiles, 1) + result = result.view(M, K).to(dtype) + if M != unpadded_m or K != unpadded_k: + result = result[:unpadded_m, :unpadded_k].contiguous() + return result + + def quantize( + self, + x: torch.Tensor, + quant_dtype: torch.dtype, + return_transpose: bool = False, + eps: float = 0.0, + pow_2_scales: bool = False, + quant_tile_shape: Tuple[int, int] = (128, 128), + ) -> QuantizeResult: + # sanity checks + assert x.dim() == 2 + assert x.dtype in ( + torch.float, + torch.float16, + torch.bfloat16, + torch.float32, + ), "Unsupported input dtype." + assert quant_dtype in ( + torch.float8_e4m3fn, + torch.float8_e5m2, + ), "Unsupported quant dtype." + + assert quant_tile_shape in ((1, 128), (128, 128)) + if quant_tile_shape[0] == 1: + # Quantize row-wise + return self.scale_munger.munge_scale_shapes_for_backend( + self._quantize_vector_tiling( + x, + quant_dtype, + tile_len=quant_tile_shape[1], + return_transpose=return_transpose, + pow_2_scales=pow_2_scales, + eps=eps, + ), + quant_tile_shape, + ) + else: + # Quantize block-wise + return self.scale_munger.munge_scale_shapes_for_backend( + self._quantize_square_block_tiling( + x, + quant_dtype, + tile_len=quant_tile_shape[0], + return_transpose=return_transpose, + pow_2_scales=pow_2_scales, + eps=eps, + ), + quant_tile_shape, + ) diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py new file mode 100644 index 0000000000..3603268d09 --- /dev/null +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -0,0 +1,288 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from typing import Tuple +import math +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex + +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + Float8BlockwiseQTensor, +) +from tests.pytorch.references.blockwise_quantizer_reference import BlockwiseQuantizerReference, QuantizeResult + +def initialize_for_many_scales( + x_shape_2d: Tuple[int, int], + tile_shape: Tuple[int, int], + *, + dtype: torch.dtype, + device: str +) -> torch.Tensor: + """ + Put separate distributions into each quantization tile + to avoid many tiles having similar scale values and + causing false passes. + """ + tile_grid_shape = ( + math.ceil(x_shape_2d[0] / tile_shape[0]), + math.ceil(x_shape_2d[1] / tile_shape[1]), + ) + # Arbitrary size + max_val = 8192.0 + # Make a uniform distribution of [-max_val, max_val] + tile_extrema = torch.rand(*tile_grid_shape, dtype=dtype) * max_val * 2 - max_val + result = torch.empty(x_shape_2d, dtype=dtype, device=device) + tile_elements = tile_shape[0] * tile_shape[1] + for i in range(tile_grid_shape[0]): + for j in range(tile_grid_shape[1]): + target = tile_extrema[i, j].item() + step = target / (tile_elements) + if target == 0: + tile = torch.zeros(tile_shape, dtype=dtype, device=device) + else: + tile = torch.arange(0.0, target, step=step, dtype=dtype, device=device) + tile = tile.reshape(*tile_shape) + min_dst_vals = (i * tile_shape[0], j * tile_shape[1]) + max_dst_vals = ( + min((i + 1) * tile_shape[0], x_shape_2d[0]), + min((j + 1) * tile_shape[1], x_shape_2d[1]), + ) + max_src_vals = ( + max_dst_vals[0] - min_dst_vals[0], + max_dst_vals[1] - min_dst_vals[1], + ) + result[ + min_dst_vals[0] : max_dst_vals[0], min_dst_vals[1] : max_dst_vals[1] + ] = tile[: max_src_vals[0], : max_src_vals[1]] + return result + + +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (300, 300), + (305, 256), + # Some larger tiles. + (2000, 2000), + (2048, 2000), + (2000, 1024), + (2048, 1024), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str +) +@pytest.mark.parametrize("eps", [0, 1e-12], ids=["eps_0", "eps_1e-12"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] +) +@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "f32scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + return_transpose: bool, + pow_2_scales: bool, + tile_size: Tuple[int, int], +) -> None: + te_dtype = TE_DType[quant_dtype] + if tile_size == (1, 128): + block_scaling_dim=1 + elif tile_size == (128, 128): + block_scaling_dim=2 + else: + raise ValueError("Non support tile size") + # This test runs a comparison of the ref class versus the class using + # CUDA kernels to quantize. They should quantize identically for pixels + # that are not DC values in the scale factor shape. + ref_quantizer = BlockwiseQuantizerReference() + sut_quantizer = Float8BlockQuantizer(fp8_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + amax_epsilon=eps, + force_pow_2_scales=pow_2_scales, + block_scaling_dim=block_scaling_dim) + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device) + + x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False) + x_fp8_sut = sut_quantizer.update_quantized( + x, + x_fp8_sut + ) + + assert x_fp8_sut._rowwise_data is not None + qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype) + assert x_fp8_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv + qx_t = x_fp8_sut._columnwise_data + sx_t = x_fp8_sut._columnwise_scale_inv + + qresult_ref = ref_quantizer.quantize( + x, quant_dtype=quant_dtype, return_transpose=return_transpose, + eps=eps, pow_2_scales=pow_2_scales, quant_tile_shape=tile_size + ) + qx_ref, sx_ref, qx_t_ref, sx_t_ref = ( + qresult_ref.data, + qresult_ref.scale, + qresult_ref.data_t, + qresult_ref.scale_t, + ) + + # Check + torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0) + if (tile_size[0] != 1): + # Zero out values that are don't care values + # cuBLAS has padding of 2D tensors. + scale_mask = torch.ones( + (math.ceil(M / tile_size[0]), math.ceil(N / tile_size[1])), device=sx.device + ) + scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend( + QuantizeResult(qx, scale_mask, None, None), tile_size + ).scale + sx = sx * scale_mask + + torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + assert qx_t is not None + qx_t = qx_t.view(dtype=quant_dtype) + assert qx_t_ref is not None + assert sx_t is not None + assert sx_t_ref is not None + if (tile_size[0] != 1): + scale_mask = torch.ones( + (math.ceil(N / tile_size[0]), math.ceil(M / tile_size[1])), + device=sx_t.device, + ) + scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend( + QuantizeResult(qx_t, scale_mask, None, None), tile_size + ).scale + sx_t = sx_t * scale_mask + torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0) + torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0) + else: + # should be None + assert qx_t is None and qx_t_ref is None + assert sx_t is None and sx_t_ref is None + +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (1, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str +) +@pytest.mark.parametrize("eps", [0, math.pow(2, -125)], ids=["eps_0", "eps_small"]) +@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "f32scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128)]) +@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) +def test_quantization_block_tiling_extrema_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + pow_2_scales: bool, + tile_size: Tuple[int, int], + extrema_high: bool, +) -> None: + # This test runs a single tile through a quantizer as a way to test + # branch coverage of scale computation. + te_dtype = TE_DType[quant_dtype] + if tile_size == (1, 128): + block_scaling_dim=1 + elif tile_size == (128, 128): + block_scaling_dim=2 + else: + raise ValueError("Non support tile size") + ref_quantizer = BlockwiseQuantizerReference() + sut_quantizer = Float8BlockQuantizer(fp8_dtype=te_dtype, + rowwise=True, + columnwise=False, + amax_epsilon=eps, + force_pow_2_scales=pow_2_scales, + block_scaling_dim=block_scaling_dim) + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + return_transpose = False + # Input + if extrema_high: + x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device) + else: + x = torch.zeros((M, N), dtype=x_dtype, device=device) + + # Run cast and transpose kernel + # Internal call ops.quantize_tensorwise + x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False) + x_fp8_sut = sut_quantizer.update_quantized( + x, + x_fp8_sut + ) + qx = x_fp8_sut._rowwise_data.view(dtype=quant_dtype) + sx = x_fp8_sut._rowwise_scale_inv + + qresult_ref = ref_quantizer.quantize( + x, quant_dtype=quant_dtype, return_transpose=return_transpose, + eps=eps, pow_2_scales=pow_2_scales, quant_tile_shape=tile_size + ) + qx_ref, sx_ref = ( + qresult_ref.data, + qresult_ref.scale, + ) + + # Check + torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0) + torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0) + + if extrema_high: + expected_value = torch.finfo(quant_dtype).max / torch.finfo(x_dtype).max + if pow_2_scales: + expected_value = math.floor(math.log2(expected_value)) + expected_value = math.pow(2.0, expected_value) + expected_value = 1 / expected_value + elif not extrema_high and eps == 0: + expected_value = 1.0 + else: + assert not extrema_high + # eps is small enough to trigger inf in quant_dtype_max / eps + if pow_2_scales: + expected_value = math.pow(2.0, -127) + else: + expected_value = 1 / torch.finfo(x_dtype).max + torch.testing.assert_close( + sx, + torch.tensor([expected_value], device=sx.device).reshape(1, 1), + atol=0.0, + rtol=0.0, + ) diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py new file mode 100644 index 0000000000..6deb714ce2 --- /dev/null +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -0,0 +1,207 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from collections.abc import Iterable +import io +import math +from typing import Any, Dict, List, Tuple, Union + +import pytest +import torch + +import transformer_engine.common.recipe +import transformer_engine.pytorch as te +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + Float8BlockwiseQTensor, +) +import transformer_engine_torch as tex + +# PyTorch tensor dtypes +_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] +# TE FP8 dtypes +_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] + +# Numerical tolerances with FP8 types +_tols: Dict[tex.DType, Dict[str, float]] = { + tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675), # epsilon = 0.0625 + tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125 +} + + +def _to_list(x: Union[Iterable, Any]) -> List: + """Convert to list if iterable, otherwise put in singleton list""" + if isinstance(x, Iterable): + return list(x) + else: + return [x] + + +# Types that can be interpreted as tensor dims +DimsType = Union[Iterable[int], int] + +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +class TestFloat8BlockwiseTensor: + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + def test_constructor( + self, + dims: DimsType = 1, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + dtype: torch.dtype = torch.float32, + ) -> None: + """Call constructor and perform sanity checks""" + dims = _to_list(dims) + + rowwise = True + columnwise = True + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, rowwise=rowwise, columnwise=columnwise + ) + + scale_dims = quantizer.get_scale_shape(dims, columnwise=False) + columnwise_scale_dims = quantizer.get_scale_shape(dims, columnwise=True) + columnwise_dims = quantizer.get_columnwise_shape(dims) + tensor = Float8BlockwiseQTensor( + shape=dims, + dtype=dtype, + rowwise_data=torch.zeros(dims, device="cuda", dtype=torch.uint8), + rowwise_scale_inv=torch.zeros( + scale_dims, device="cuda", dtype=torch.float32 + ), + columnwise_data=torch.zeros( + columnwise_dims, device="cuda", dtype=torch.uint8 + ), + columnwise_scale_inv=torch.zeros( + columnwise_scale_dims, device="cuda", dtype=torch.float32 + ), + fp8_dtype=fp8_dtype, + quantizer=quantizer, + ) + assert list(tensor.size()) == dims, "Incorrect dims" + assert tensor.dtype == dtype, "Incorrect nominal dtype" + assert tensor.is_cuda, "Incorrect device" + + def _test_quantize_dequantize( + self, + quantizer: Float8BlockQuantizer, + dtype: torch.dtype = torch.float32, + dims: DimsType = (23, 128), + rtol: float = 0.0, + atol: float = 0.0, + dequant_columnwise: bool = False, + use_cpp_allocation: bool = False, + ) -> None: + """Check numerical error when casting to FP8 and back""" + dims = _to_list(dims) + + # Initialize random data + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_ref_cuda = x_ref.to("cuda") + + # Cast to FP8 and back + if not use_cpp_allocation: + x_fp8 = quantizer.make_empty(shape=dims, device="cuda") + quantizer.update_quantized(x_ref_cuda, x_fp8) + else: + # This codepath allows the CPP binding to allocate the output + # tensor + x_fp8 = tex.quantize(x_ref_cuda, quantizer, None, None) + if dequant_columnwise: + # Strip out rowwise data to verify dequantization of + # columnwise data. + x_fp8.update_usage(rowwise_usage=False, columnwise_usage=True) + x_fp8 = x_fp8.dequantize(dtype=dtype).cpu() + + # Check results + torch.testing.assert_close(x_fp8, x_ref, rtol=rtol, atol=atol) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8, -x_ref, rtol=rtol, atol=atol) + + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_quantize_dequantize_dtypes( + self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int + ) -> None: + atol = _tols[fp8_dtype]["atol"] + rtol = _tols[fp8_dtype]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=False, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize( + quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol + ) + + @pytest.mark.parametrize( + "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] + ) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + @pytest.mark.parametrize("dq_columnwise", [True, False]) + def test_quantize_dequantize_dims( + self, dims: DimsType, block_scaling_dim: int, dq_columnwise: bool + ) -> None: + atol = _tols[tex.DType.kFloat8E4M3]["atol"] + rtol = _tols[tex.DType.kFloat8E4M3]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=dq_columnwise, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize( + quantizer=quantizer, + dims=dims, + atol=atol, + rtol=rtol, + dequant_columnwise=dq_columnwise, + ) + + @pytest.mark.parametrize( + "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] + ) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dq_columnwise", [True, False]) + def test_quantize_dequantize_dims_cpp_allocate_output( + self, dims: DimsType, block_scaling_dim: int, fp8_dtype: tex.DType, dq_columnwise: bool + ) -> None: + atol = _tols[fp8_dtype]["atol"] + rtol = _tols[fp8_dtype]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=dq_columnwise, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize( + quantizer=quantizer, + dims=dims, + atol=atol, + rtol=rtol, + dequant_columnwise=dq_columnwise, + use_cpp_allocation=True, + ) + + # FIXME(kwyss): Add some testing for other tensor operations. + # - basic_ops + # - in_place_ops + # - serialization + # - set_data diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index c77d230ce5..17421d039e 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -58,6 +58,8 @@ list(APPEND transformer_engine_SOURCES transpose/cast_transpose_fusion.cu transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu + transpose/quantize_transpose_square_blockwise.cu + transpose/quantize_transpose_vector_blockwise.cu activation/gelu.cu fused_attn/fused_attn_f16_max512_seqlen.cu fused_attn/fused_attn_f16_arbitrary_seqlen.cu diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 46eb248156..fbcc0244bc 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -86,6 +86,10 @@ struct Tensor { NVTEScalingMode scaling_mode; + float amax_epsilon; + bool force_pow_2_scales; + int block_scaling_dim; + Tensor() : data(), columnwise_data(), @@ -93,7 +97,10 @@ struct Tensor { scale(nullptr, {1}, DType::kFloat32), scale_inv(nullptr, {1}, DType::kFloat32), columnwise_scale_inv(nullptr, {1}, DType::kFloat32), - scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {} + scaling_mode(NVTE_DELAYED_TENSOR_SCALING), + amax_epsilon(0.0), + force_pow_2_scales(false), + block_scaling_dim(scaling_mode == NVTE_BLOCK_SCALING ? 2 : 0) {} int numel() const { NVTE_CHECK(data.dptr != nullptr || columnwise_data.dptr != nullptr, @@ -116,6 +123,33 @@ struct Tensor { bool has_columnwise_data() const noexcept { return columnwise_data.dptr != nullptr; } + bool supports_force_pow_2_scales_qopt() const noexcept { + switch (scaling_mode) { + case NVTE_BLOCK_SCALING: + return true; + default: + return false; + } + } + + bool supports_amax_epsilon_qopt() const noexcept { + switch (scaling_mode) { + case NVTE_BLOCK_SCALING: + return true; + default: + return false; + } + } + + bool supports_block_scaling_dim(int block_scaling_dim) const noexcept { + switch (scaling_mode) { + case NVTE_BLOCK_SCALING: + return block_scaling_dim == 1 || block_scaling_dim == 2; + default: + return false; + } + } + DType dtype() const { if (has_data()) return data.dtype; if (has_columnwise_data()) return columnwise_data.dtype; @@ -396,6 +430,19 @@ struct TypeInfo { } \ } +#define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \ + if (CONDITION) { \ + constexpr bool FLAG = true; \ + { \ + __VA_ARGS__ \ + } \ + } else { \ + constexpr bool FLAG = false; \ + { \ + __VA_ARGS__ \ + } \ + } + //////////////////////////////////////////////////////////////////////////////////////////////////// inline int log2_ceil(int value) { diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 52fa89b914..cb1cc0e757 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -76,6 +76,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const transformer_engine::Tensor &B, const cublasOperation_t transB, const int k, const int lda, const int ldb) { using namespace transformer_engine; + // FIXME(kwyss): 1x128 by 128x128 GEMM is part of the subchannel design. + // Must either force them both into a common block scaling mode or loosen this + // restriction. NVTE_CHECK(A.scaling_mode == B.scaling_mode, "Inputs A and B to GEMM need to have the same scaling mode!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); @@ -85,6 +88,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.lda = lda; ret.ldb = ldb; + // FIXME(kwyss): 128x128 by 128x128 GEMMs and 1x128 by 128x128 GEMMs need cases + // or need to be treated as `is_tensor_scaling`. if (is_tensor_scaling(A.scaling_mode)) { ret.A = A.data.dptr; ret.A_scale_inv = A.scale_inv.dptr; @@ -238,6 +243,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); + // FIXME(kwyss): Add binding code for 128x128 block quantized 1x128 block quantized + // GEMM types. + // Scaling factors. #if CUDA_VERSION >= 12080 cublasLtMatmulMatrixScale_t scaling_mode; diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index d57975b2f4..9be0e14d8a 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -42,23 +42,25 @@ extern "C" { * of the output tensor should be set to 0. */ -/*! \brief Casts input tensor to FP8/MXFP8. +/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, - * the block quantization (MXFP8) of the specified shape of the block will be used. + * the MXFP8 block quantization of the specified shape of the block will be used. + * If the scaling mode of the output tensor is set to NVTE_BLOCK_SCALING, + * blockwise float8 scaling will be used. * * \param[in] input Input tensor to be cast. - * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[in,out] output Output FP8/MXFP8/BlockwiseFP8 tensor. * \param[in] stream CUDA stream used for the operation. */ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. - * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, - * the block quantization (MXFP8) of the specified shape of the block will be used. + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. * * \param[in] input Input tensor to be cast. - * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[in,out] output Output quantized tensor. * \param[out] noop Noop tensor. * \param[in] stream CUDA stream used for the operation. */ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index e393dbffc4..ac0a4a0a77 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -77,6 +77,11 @@ enum NVTEScalingMode { /*! Single scale per block of 32 elements consecutive in either rowwise or columnwise direction */ NVTE_MXFP8_1D_SCALING = 1, + /*! Tensor is split into NxN quantization tiles or 1xN quantization tiles, + which each yield a scale. The block_scaling_dim property of the quantizer + selects the granularity. + */ + NVTE_BLOCK_SCALING = 2, NVTE_INVALID_SCALING }; @@ -231,6 +236,63 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream); void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, const NVTEBasicTensor *param); +/*! \brief Set a quantization option for whether to force power of 2 scales. + * + * \param[in/out] tensor Tensor. + * \param[in] zero_if_false Whether to force power of 2 scales. + * + * \return zero if the tensor supports this option and it was set. non-zero if + * call had no effect. + */ +int nvte_set_qopt_force_pow_2_scales(NVTETensor tensor, int zero_if_false); + +/*! \brief Set a quantization option for epsilon to set floor of amax. + * + * \param[in/out] tensor Tensor. + * \param[in] amax_epsilon Epsilon to use for amax calculation. + * + * \return zero if the tensor supports this option and it was set. non-zero if + * call had no effect. + */ +int nvte_set_qopt_amax_epsilon(NVTETensor tensor, float amax_epsilon); + +/*! \brief Set a quantization option to use 1D or 2D quantization blocks + * to scale the tensor. + * + * \param[in/out] tensor Tensor. + * \param[in] block_scaling_dim, 1D or 2D. + * + * \return zero if the tensor supports this option and it was set. non-zero if + * call had no effect or the number of dims is not supported. + */ +int nvte_set_qopt_block_scaling_dim(NVTETensor tensor, int block_scaling_dim); + +/*! \brief Get a quantization option for whether to force power of 2 scales. + * + * \param[in] tensor Tensor. + * + * \return zero if the tensor will not force power of 2 scales or if the + * setting is irrelevant. non-zero if the flag is configured. + */ +int nvte_get_qopt_force_pow_2_scales(NVTETensor tensor); + +/*! \brief Get a quantization option for amax epsilon. + * + * \param[in] tensor Tensor. + * + * \return amax_epsilon value or zero if not applicable. + */ +float nvte_get_qopt_amax_epsilon(const NVTETensor tensor); + +/*! \brief Get the number of dimensions in the quantization blocks. + * + * \param[in] tensor Tensor. + * + * \return zero if the quantization does not support the block_scaling_dim + * option or the block_scaling_dim configured. + */ +int nvte_get_qopt_block_scaling_dim(const NVTETensor tensor); + /*! \brief Get a value of the parameter of the tensor. * * \param[in] tensor Tensor. @@ -598,6 +660,24 @@ class TensorWrapper { void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); } + int set_qopt_force_pow_2_scales(bool flag) { + return nvte_set_qopt_force_pow_2_scales(tensor_, flag ? 1 : 0); + } + + int set_qopt_amax_epsilon(float eps) { return nvte_set_qopt_amax_epsilon(tensor_, eps); } + + int set_qopt_block_scaling_dim(int block_scaling_dim) { + return nvte_set_qopt_block_scaling_dim(tensor_, block_scaling_dim); + } + + bool get_qopt_force_pow_2_scales() const { + return nvte_get_qopt_force_pow_2_scales(tensor_) != 0; + } + + float get_qopt_amax_epsilon() const { return nvte_get_qopt_amax_epsilon(tensor_); } + + int get_qopt_block_scaling_dim() const { return nvte_get_qopt_block_scaling_dim(tensor_); } + static constexpr size_t defaultData = 1; static constexpr NVTEShape defaultShape = {&defaultData, 1}; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index faf6ec990d..9e9174237a 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -412,3 +412,48 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { } cudaStreamSynchronize(stream); } + +int nvte_set_qopt_force_pow_2_scales(NVTETensor tensor, int zero_if_false) { + auto &t = *reinterpret_cast(tensor); + if (t.supports_force_pow_2_scales_qopt()) { + t.force_pow_2_scales = zero_if_false != 0; + return 0; + } else { + return 1; + } +} + +int nvte_set_qopt_amax_epsilon(NVTETensor tensor, float amax_epsilon) { + auto &t = *reinterpret_cast(tensor); + if (t.supports_amax_epsilon_qopt()) { + t.amax_epsilon = amax_epsilon; + return 0; + } else { + return 1; + } +} + +int nvte_set_qopt_block_scaling_dim(NVTETensor tensor, int block_scaling_dim) { + auto &t = *reinterpret_cast(tensor); + if (t.supports_block_scaling_dim(block_scaling_dim)) { + t.block_scaling_dim = block_scaling_dim; + return 0; + } else { + return 1; + } +} + +int nvte_get_qopt_force_pow_2_scales(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return t.force_pow_2_scales ? 1 : 0; +} + +float nvte_get_qopt_amax_epsilon(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return t.amax_epsilon; +} + +int nvte_get_qopt_block_scaling_dim(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return t.block_scaling_dim; +} diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index ed9bd5f5f7..893207a51a 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -23,6 +23,26 @@ template +#include +#include + +#include +#include + +namespace transformer_engine { + +// Type trait for extreme values of fp8 types. +// Used in the calculation of scale factors +// as a constexpr lookup from e4m3 or e5m2 to +// the max finite value. +template +struct F8LimitsTrait; + +template <> +struct F8LimitsTrait<__nv_fp8_e4m3> { + static constexpr float max = 448.0f; +}; + +template <> +struct F8LimitsTrait<__nv_fp8_e5m2> { + static constexpr float max = 57344.0f; +}; + +// Type trait to resolve the max finite value +// represented by a input type to quantization. +// Or to represent max representable power of 2 +// finite value. +template +struct HighPrecisionFloatScaleLimitsTrait; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + static constexpr float max = std::numeric_limits::max(); +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.0 * 2 ^ 127 + static constexpr float max = 0x1.0p127; +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.(7 bits of 1) * 2 ^ 127 + static constexpr float max = 0x1.FEp127; +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.0 * 2 ^ 127 + static constexpr float max = 0x1.0p127; +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.(10 bits of 1) * 2 ^ 15 + static constexpr float max = 0x1.FFCp15; +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.0 * 2 ^ 15 + static constexpr float max = 0x1.0p15; +}; + +// Calculate the quantization scale for an individual data element +// given the amax(abs(tile)) value for a given quantization tile. +// +// +// Arguments: +// IType: data type of the tensor being quantized (float or bf16) +// OType: quantized data type (e4m3 or e5m2) +// pow_2_scaling: Whether to force the scale to be a power of 2. +// amax: The evaluation of amax(abs(tile)) for the quantization tile. +// eps: An epsilon used as a floor for amax. +template +__device__ __forceinline__ float ComputeScale(const float amax, const float eps) { + constexpr float fp8_max = F8LimitsTrait::max; + + // Clamping amax to avoid division by small numbers + float amax_mod = fmaxf(amax, eps); + + // Handle overflow cases for non-clamped amax (eps is 0 or very small) + if (amax_mod == 0.f) { + // If amax is 0, return 1 + return 1.f; + } + // Compute scale factor + float scale = fp8_max / amax_mod; + + if (isinf(scale)) { + // If scale is infinity, return max value of IType + return HighPrecisionFloatScaleLimitsTrait::max; + } + if (scale == 0.0) { + // Case that amax is "inf". The frexp, ldexp logic changes 0.0 scales. + // Return 0.0 for 0.0 scale here is consistent with non-Power2Scaling model. + // quantization will remove signal from the tensor, + // this is bad for the model, but define pow2Scale behavior + // as returning 0.0 scale. amax calculation can + // improve the situation to avoid this by taking largest finite. + return scale; + } + if constexpr (Power2Scaling) { + // NOTE: using bit fiddling based on advice of Asit in this + // thread: https://nvidia.slack.com/archives/C06EDT7LZEW/p1738274404153439 + + // inf scales already early returned, as did nan scales. + // The cases to consider here are normals, zero, and subnormals. + // zero is not possible with current math as + // 448.0 / float_max == 1.31655e-36, which is the smallest + // possible scale given current dtypes. It is still in the normal + // fp32 range with an exponent of -120, so subnormals are also + // not possible. To handle normals, we can simply mask off the + // mantissa. + uint32_t scale_bits = *reinterpret_cast(&scale); + scale_bits &= 0xFF800000; + // If the exponent was zero, we have a logic error. + __builtin_assume(scale_bits != 0); + __builtin_assume(scale_bits != 0x80000000); + scale = *reinterpret_cast(&scale_bits); + } + return scale; +} + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMPUTE_SCALE_CUH_ \ No newline at end of file diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu new file mode 100644 index 0000000000..578475e219 --- /dev/null +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -0,0 +1,640 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include + +#include "common/common.h" +#include "common/utils.cuh" +#include "compute_scale.cuh" + +#if (!defined(__CUDA_MINIMUM_ARCH__)) || \ + (defined(__CUDA_MINIMUM_ARCH__) && __CUDA_MINIMUM_ARCH__ >= 900) +#define TMA_HW_SUPPORTED +#endif + +namespace transformer_engine { +namespace { + +#ifdef TMA_HW_SUPPORTED +using barrier = cuda::barrier; +namespace cde = cuda::device::experimental; +#endif + +// const values configuration + +constexpr size_t kThreadsPerWarp = 32; +#ifdef TMA_HW_SUPPORTED +constexpr size_t BLOCK_TILE_DIM = 128; +constexpr size_t WARP_TILE_DIM_X = 32; +constexpr size_t WARP_TILE_DIM_Y = 64; +constexpr size_t THREAD_TILE_DIM_X = 16; +constexpr size_t THREAD_TILE_DIM_Y = 4; +#else +constexpr size_t BLOCK_TILE_DIM = 128; +constexpr size_t WARP_TILE_DIM_X = 64; +constexpr size_t WARP_TILE_DIM_Y = 32; +constexpr size_t THREAD_TILE_DIM_X = 8; +constexpr size_t THREAD_TILE_DIM_Y = 8; +#endif + +#ifdef TMA_HW_SUPPORTED +constexpr size_t NUM_BYTES_PER_BANK = 4; +constexpr size_t NUM_BANKS_PER_SHARED_ELEM = THREAD_TILE_DIM_Y / NUM_BYTES_PER_BANK; +constexpr size_t SHARED_BLOCK_TILE_DIM_Y = BLOCK_TILE_DIM; +constexpr size_t SHARED_BLOCK_TILE_DIM_X_BANKS = + BLOCK_TILE_DIM / (NUM_BYTES_PER_BANK * NUM_BANKS_PER_SHARED_ELEM); +constexpr size_t NUM_BANKS_Y_IN_WARP = WARP_TILE_DIM_Y / NUM_BYTES_PER_BANK; +#endif +constexpr size_t ELE_PER_THREAD = THREAD_TILE_DIM_X * THREAD_TILE_DIM_Y; +constexpr size_t THREADS_PER_BLOCK = BLOCK_TILE_DIM * BLOCK_TILE_DIM / ELE_PER_THREAD; +constexpr size_t NUM_WARPS_X_IN_BLOCK = BLOCK_TILE_DIM / WARP_TILE_DIM_X; +constexpr size_t NUM_WARPS_Y_IN_BLOCK = BLOCK_TILE_DIM / WARP_TILE_DIM_Y; +constexpr size_t NUM_WARPS_IN_BLOCK = NUM_WARPS_X_IN_BLOCK * NUM_WARPS_Y_IN_BLOCK; + +constexpr size_t NUM_THREADS_X_IN_WARP = WARP_TILE_DIM_X / THREAD_TILE_DIM_X; +constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP; + +#define MIN(a, b) (a < b ? a : b) + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel( + const IType* const input, + OType* const output_c, + OType* const output_t, + CType* const tile_scales_inv_c, + CType* const tile_scales_inv_t, + const size_t row_length, + const size_t num_rows, + const size_t scale_stride_x, + const size_t scale_stride_y, + const size_t scale_t_stride_x, + const size_t scale_t_stride_y, + const float epsilon, + const __grid_constant__ CUtensorMap tensor_map_output_t) { + using IVec = Vec; + using OVecCast = Vec; + using OVecTrans = Vec; + + // shared mem for amax reduction in entire block, each warp produces one amax, there are + // NUM_WARPS_IN_BLOCK amax to reduce + __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; + + IVec thrd_tile_input[THREAD_TILE_DIM_Y]; + constexpr int THREAD_TILE_DIM_X_ = kReturnTranspose ? THREAD_TILE_DIM_X : 1; + OVecTrans thrd_tile_out_trans[THREAD_TILE_DIM_X_]; + + const int tid_in_warp = threadIdx.x % kThreadsPerWarp; + const int tid_in_warp_x = tid_in_warp % NUM_THREADS_X_IN_WARP; + const int tid_in_warp_y = tid_in_warp / NUM_THREADS_X_IN_WARP; + const int warp_id_in_block = threadIdx.x / kThreadsPerWarp; + const int warp_id_in_block_x = warp_id_in_block % NUM_WARPS_X_IN_BLOCK; + const int warp_id_in_block_y = warp_id_in_block / NUM_WARPS_X_IN_BLOCK; + + // This is ONLY true if the input is a full tile + const int tile_id_x = blockIdx.x; + const int tile_id_y = blockIdx.y; + + const size_t block_tile_start_idx = + tile_id_y * BLOCK_TILE_DIM * row_length + tile_id_x * BLOCK_TILE_DIM; + const size_t warp_tile_start_idx = + block_tile_start_idx + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP * row_length + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP; + const size_t thread_tile_start_idx = warp_tile_start_idx + + tid_in_warp_y * THREAD_TILE_DIM_Y * row_length + + tid_in_warp_x * THREAD_TILE_DIM_X; + + CType warp_tile_amax; + CType block_tile_amax; + CType block_tile_scale; + CType amax = 0; + +// Step 1: Load a block tile of input data into thread tiles on registers +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + thrd_tile_input[i].load_from(input + thread_tile_start_idx + i * row_length); + } + + // Step 2: calculate block tile amax and scale + // Calculate thread_tile amax + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(static_cast(thrd_tile_input[i].data.elt[j]))); + } + } + // Reduce amax in the warp (32x32 tile) + warp_tile_amax = warp_reduce_max(amax); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); + + // reduce warp_tile_amax across multiple warps in a thread block using shared mem + if (tid_in_warp == 0) { + block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK + warp_id_in_block_x] = + warp_tile_amax; + } + __syncthreads(); + // only 8 elements needs reduction, if using reduction tree, multiple _syncthreads will be needed, + // instead we just let thread 0 do the job + if (threadIdx.x == 0) { + CType blk_amax = block_tile_amax_shared[0]; +#pragma unroll + for (int idx = 1; idx < NUM_WARPS_IN_BLOCK; idx++) { + blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]); + } + block_tile_amax_shared[0] = blk_amax; + } + __syncthreads(); + block_tile_amax = block_tile_amax_shared[0]; + + block_tile_scale = ComputeScale(block_tile_amax, epsilon); + + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + const CType scale_inv = 1.0f / block_tile_scale; + + size_t row_idx = tile_id_y; + size_t col_idx = tile_id_x; + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + + if constexpr (kReturnTranspose) { + row_idx = tile_id_x; + col_idx = tile_id_y; + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + + // Step 3: Store cast output, Step 4: do transpose within thread tile + OVecCast tmp_output_c; + + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + // Step 3: Store cast output + CType scale_data = block_tile_scale; + + OType scaled_elt = + static_cast(static_cast(thrd_tile_input[i].data.elt[j]) * scale_data); + tmp_output_c.data.elt[j] = scaled_elt; + // Step 4: do transpose within thread tile + if constexpr (kReturnTranspose) { + thrd_tile_out_trans[j].data.elt[i] = scaled_elt; + } + } + tmp_output_c.store_to(output_c + thread_tile_start_idx + i * row_length); + } + + // Step 4: store transpose into shared memory + if constexpr (kReturnTranspose) { +#ifdef TMA_HW_SUPPORTED + __shared__ alignas(128) + OVecTrans block_tile_trans_shared[SHARED_BLOCK_TILE_DIM_Y][SHARED_BLOCK_TILE_DIM_X_BANKS]; + OType(*block_tile_trans_shared_otype_ptr)[BLOCK_TILE_DIM] = + reinterpret_cast(block_tile_trans_shared); + +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_X; i++) { + auto warp_id_in_block_x_ = warp_id_in_block_y; + auto warp_id_in_block_y_ = warp_id_in_block_x; + int row_idx = warp_id_in_block_y_ * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP + + tid_in_warp_x * THREAD_TILE_DIM_X + i; + int col_idx = + warp_id_in_block_x_ * (NUM_BANKS_Y_IN_WARP / NUM_BANKS_PER_SHARED_ELEM) + tid_in_warp_y; + block_tile_trans_shared[row_idx][col_idx] = thrd_tile_out_trans[i]; + } + + // Wait for shared memory writes to be visible to TMA engine. + cde::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Step 5: store transpose output + // Initiate TMA transfer to copy shared memory to global memory + if (threadIdx.x == 0) { + cde::cp_async_bulk_tensor_2d_shared_to_global( + &tensor_map_output_t, + tile_id_y * BLOCK_TILE_DIM, + tile_id_x * BLOCK_TILE_DIM, + block_tile_trans_shared_otype_ptr); + // Wait for TMA transfer to have finished reading shared memory. + // Create a "bulk async-group" out of the previous bulk copy operation. + cde::cp_async_bulk_commit_group(); + // Wait for the group to have completed reading from shared memory. + cde::cp_async_bulk_wait_group_read<0>(); + } +#else + // Step 4 Alternative (when TMA is not available, skip writing to shared memory) + const size_t block_tile_t_start_idx = + tile_id_x * BLOCK_TILE_DIM * num_rows + tile_id_y * BLOCK_TILE_DIM; + const size_t warp_tile_t_start_idx = + block_tile_t_start_idx + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP; + const size_t thread_tile_t_start_idx = warp_tile_t_start_idx + + tid_in_warp_x * THREAD_TILE_DIM_X * num_rows + + tid_in_warp_y * THREAD_TILE_DIM_Y; +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_X; i++) { + thrd_tile_out_trans[i].store_to(output_t + thread_tile_t_start_idx + i * num_rows); + } +#endif + } +} + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned( + const IType* const input, + OType* const output_c, + OType* const output_t, + CType* const tile_scales_inv_c, + CType* const tile_scales_inv_t, + const size_t row_length, + const size_t num_rows, + const size_t scale_stride_x, + const size_t scale_stride_y, + const size_t scale_t_stride_x, + const size_t scale_t_stride_y, + const float epsilon) { + using IVec = Vec; + using OVecCast = Vec; + using OVecTrans = Vec; + + // shared mem for amax reduction in entire block, each warp produces one amax, there are + // NUM_WARPS_IN_BLOCK amax to reduce + __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; + + IVec thrd_tile_input[THREAD_TILE_DIM_Y]; + constexpr int THREAD_TILE_DIM_X_ = kReturnTranspose ? THREAD_TILE_DIM_X : 1; + OVecTrans thrd_tile_out_trans[THREAD_TILE_DIM_X_]; + + const int tid_in_warp = threadIdx.x % kThreadsPerWarp; + const int tid_in_warp_x = tid_in_warp % NUM_THREADS_X_IN_WARP; + const int tid_in_warp_y = tid_in_warp / NUM_THREADS_X_IN_WARP; + const int warp_id_in_block = threadIdx.x / kThreadsPerWarp; + const int warp_id_in_block_x = warp_id_in_block % NUM_WARPS_X_IN_BLOCK; + const int warp_id_in_block_y = warp_id_in_block / NUM_WARPS_X_IN_BLOCK; + + const int tile_id_x = blockIdx.x; + const int tile_id_y = blockIdx.y; + + const size_t block_tile_start_row_idx = tile_id_y * BLOCK_TILE_DIM; + const size_t block_tile_start_col_idx = tile_id_x * BLOCK_TILE_DIM; + const size_t block_tile_start_idx = + block_tile_start_row_idx * row_length + block_tile_start_col_idx; + const size_t warp_tile_start_idx = + block_tile_start_idx + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP * row_length + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP; + const size_t thread_tile_start_idx = warp_tile_start_idx + + tid_in_warp_y * THREAD_TILE_DIM_Y * row_length + + tid_in_warp_x * THREAD_TILE_DIM_X; + + // handle non-full tile + // check for three cases: full thread tile, nonfull thread tile, empty thread tile + // for empty thread tile, directly write zero to the transposed shared mem buffer + // for nonfull thread tile, fill zero to thread tile and act as if it's full + const size_t thread_tile_start_row_idx = + tile_id_y * BLOCK_TILE_DIM + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP + + tid_in_warp_y * THREAD_TILE_DIM_Y; + const size_t thread_tile_start_col_idx = + tile_id_x * BLOCK_TILE_DIM + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP + + tid_in_warp_x * THREAD_TILE_DIM_X; + + const size_t thread_tile_end_row_idx = thread_tile_start_row_idx + THREAD_TILE_DIM_Y - 1; + const size_t thread_tile_end_col_idx = thread_tile_start_col_idx + THREAD_TILE_DIM_X - 1; + + bool full_thrd_tile = + (thread_tile_end_row_idx < num_rows) && (thread_tile_end_col_idx < row_length); + bool empty_thrd_tile = + (thread_tile_start_row_idx >= num_rows) || (thread_tile_start_col_idx >= row_length); + bool nonfull_thrd_tile = (!full_thrd_tile) && (!empty_thrd_tile); + + const size_t thread_tile_ncols = + MIN(THREAD_TILE_DIM_X, + (MIN(thread_tile_end_col_idx, row_length - 1) - thread_tile_start_col_idx + 1)); + const size_t thread_tile_nrows = + MIN(THREAD_TILE_DIM_Y, + (MIN(thread_tile_end_row_idx, num_rows - 1) - thread_tile_start_row_idx + 1)); + + CType warp_tile_amax; + CType block_tile_amax; + CType block_tile_scale; + CType amax = 0; + + if (!empty_thrd_tile) { + // Step 1: Load a block tile of input data into thread tiles on registers + // Edge case: nonfull thread tile case, will use the partial load function here + if (nonfull_thrd_tile) { +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + if (i >= thread_tile_nrows) { + thrd_tile_input[i].clear(); + } else { + thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0, + thread_tile_ncols); + } + } + } else { +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0, + THREAD_TILE_DIM_X); + } + } + + // Step 2: calculate block tile amax and scale + // Calculate thread_tile amax + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(static_cast(thrd_tile_input[i].data.elt[j]))); + } + } + } + // Reduce amax in the warp (32x32 tile) + warp_tile_amax = warp_reduce_max(amax); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); + + // reduce warp_tile_amax across multiple warps in a thread block using shared mem + if (tid_in_warp == 0) { + block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK + warp_id_in_block_x] = + warp_tile_amax; + } + __syncthreads(); + // only 8 elements needs reduction, if using reduction tree, multiple _syncthreads will be needed, + // instead we just let thread 0 do the job + if (threadIdx.x == 0) { + CType blk_amax = block_tile_amax_shared[0]; +#pragma unroll + for (int idx = 1; idx < NUM_WARPS_IN_BLOCK; idx++) { + blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]); + } + block_tile_amax_shared[0] = blk_amax; + } + __syncthreads(); + block_tile_amax = block_tile_amax_shared[0]; + + block_tile_scale = ComputeScale(block_tile_amax, epsilon); + + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + const CType scale_inv = 1.0f / block_tile_scale; + + size_t row_idx = tile_id_y; + size_t col_idx = tile_id_x; + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + + if constexpr (kReturnTranspose) { + row_idx = tile_id_x; + col_idx = tile_id_y; + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + + // Step 3: Store cast output, Step 4: do transpose within thread tile + // Edge case: in the non-full tile case, there are three subcases + // for full thread tile, it's the same thing here + // for nonfull thread tile, pay attention when saving tmp_output_c to global + // memory, cannot vec store_to, but need to elt store to for empty tile, + // it should not enter this step, skip to Step 4 + + // set thrd_tile_out_trans to all zero + if constexpr (kReturnTranspose) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + thrd_tile_out_trans[j].clear(); + } + } + + if (!empty_thrd_tile) { + OVecCast tmp_output_c; + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + if (i >= thread_tile_nrows) { + continue; + } +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + // Step 3: Store cast output + CType scale_data = block_tile_scale; + + OType scaled_elt = + static_cast(static_cast(thrd_tile_input[i].data.elt[j]) * scale_data); + tmp_output_c.data.elt[j] = scaled_elt; + // Step 4: do transpose within thread tile + if constexpr (kReturnTranspose) { + thrd_tile_out_trans[j].data.elt[i] = scaled_elt; + } + } + tmp_output_c.store_to_elts(output_c + thread_tile_start_idx + i * row_length, 0, + thread_tile_ncols); + } + + if constexpr (kReturnTranspose) { + const size_t block_tile_t_start_idx = + tile_id_x * BLOCK_TILE_DIM * num_rows + tile_id_y * BLOCK_TILE_DIM; + const size_t warp_tile_t_start_idx = + block_tile_t_start_idx + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP; + const size_t thread_tile_t_start_idx = warp_tile_t_start_idx + + tid_in_warp_x * THREAD_TILE_DIM_X * num_rows + + tid_in_warp_y * THREAD_TILE_DIM_Y; +#pragma unroll + for (int i = 0; i < thread_tile_ncols; i++) { + thrd_tile_out_trans[i].store_to_elts(output_t + thread_tile_t_start_idx + i * num_rows, 0, + thread_tile_nrows); + } + } + } +} + +PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() { + void* driver_ptr = nullptr; + cudaDriverEntryPointQueryResult driver_status; + NVTE_CHECK_CUDA(cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, cudaEnableDefault, + &driver_status)); + return reinterpret_cast(driver_ptr); +} + +template +CUtensorMap get_tensor_map(SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) { + // example-begin create-tensor-map + CUtensorMap tensor_map_output_trans{}; + // rank is the number of dimensions of the array. + constexpr uint32_t rank = 2; + uint64_t size[rank] = {global_dim_x, global_dim_y}; // x, y + // The stride is the number of bytes to traverse from the first element of one row to the next. + // It must be a multiple of 16. + uint64_t stride[rank - 1] = {global_dim_x * sizeof(OutputType)}; + // The box_size is the size of the shared memory buffer that is used as the + // destination of a TMA transfer. + uint32_t box_size[rank] = {BLOCK_TILE_DIM, BLOCK_TILE_DIM}; + // The distance between elements in units of sizeof(element). A stride of 2 + // can be used to load only the real component of a complex-valued tensor, for instance. + uint32_t elem_stride[rank] = {1, 1}; + + // Get a function pointer to the cuTensorMapEncodeTiled driver API. + auto cuTensorMapEncodeTiled = get_cuTensorMapEncodeTiled(); + CUtensorMapDataType dataType; + + if constexpr (std::is_same_v || + std::is_same_v) { + dataType = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else { + NVTE_CHECK(false, "Invalid Output type (must be FP8)."); + } + + // Create the tensor descriptor. + CUresult res = cuTensorMapEncodeTiled( + &tensor_map_output_trans, // CUtensorMap *tensorMap, + dataType, + rank, // cuuint32_t tensorRank, + reinterpret_cast(tensor.dptr), // void *globalAddress, + size, // const cuuint64_t *globalDim, + stride, // const cuuint64_t *globalStrides, + box_size, // const cuuint32_t *boxDim, + elem_stride, // const cuuint32_t *elementStrides, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + // Swizzling can be used to avoid shared memory bank conflicts. + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + // Any element that is outside of bounds will be set to zero by the TMA transfer. + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + + return tensor_map_output_trans; +} + +} // namespace +} // namespace transformer_engine + +namespace transformer_engine::detail { + +void nvte_quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, + SimpleTensor& scale_inv_t, SimpleTensor& output, + SimpleTensor& output_t, const float epsilon, + const bool return_transpose, const bool pow_2_scale, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_transpose_square_blockwise); + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; + size_t num_rows = 1; + for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) { + num_rows *= input.shape.at(i); + } + + NVTE_CHECK(scale_inv.shape.size() == 2, "scale_inv must have 2 dimensions."); + + size_t scale_k = scale_inv.shape[1]; + + const size_t scale_stride_x = 1; + const size_t scale_stride_y = scale_k; + + size_t scale_t_stride_x = 0; + size_t scale_t_stride_y = 0; + + if (return_transpose) { + NVTE_CHECK(output_t.shape.size() == input.shape.size(), + "output_t must have same number of dimensions as input."); + if (output_t.shape.size() > 0) { + NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t."); + for (size_t i = 1; i < output_t.shape.size(); ++i) { + NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t"); + } + } + NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type."); + + NVTE_CHECK(scale_inv_t.shape.size() == 2, "scale_inv_t must have 2 dimensions."); + + scale_t_stride_x = 1; + scale_t_stride_y = scale_inv_t.shape[1]; + } + + const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM); + const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype, + InputType, + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output.dtype, + OutputType, + + dim3 grid(num_blocks_x, num_blocks_y, 1); + const bool full_tile = row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transpose, + kReturnTranspose, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + pow_2_scale, + kPow2Scale, + + if (full_tile) { + CUtensorMap tensor_map_output_trans; + if constexpr (kReturnTranspose) { + tensor_map_output_trans = + get_tensor_map(output_t, num_rows, row_length); + } + block_scaled_cast_transpose_kernel< + kReturnTranspose, + kPow2Scale, + float, + InputType, + OutputType><<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), + row_length, + num_rows, + scale_stride_x, + scale_stride_y, + scale_t_stride_x, + scale_t_stride_y, + epsilon, + tensor_map_output_trans); + } else { + block_scaled_cast_transpose_kernel_notaligned< + kReturnTranspose, + kPow2Scale, + float, + InputType, + OutputType><<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), + row_length, + num_rows, + scale_stride_x, + scale_stride_y, + scale_t_stride_x, + scale_t_stride_y, + epsilon); + } // full-tile + + ) // kPow2Scale + ) // kReturnTranspose + ) // OutputType + ) // InputType + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace transformer_engine::detail \ No newline at end of file diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu new file mode 100644 index 0000000000..1e73e74987 --- /dev/null +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -0,0 +1,563 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/utils.cuh" +#include "compute_scale.cuh" + +namespace transformer_engine { +namespace { + +// clang-format off +/* + +Step 1: Load input to shared memory +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 8 times +* What each thread does in each loop: + * 8 elements are read from the input at a time + * 2 elements are written to the shared memory at a time, for a total of 4 times ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 1 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 7 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 8 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 2: Cast and store to output_c +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 4 times +* What each thread does in each loop: + * 2 elements are read from the shared memory at a time, for a total of 8 times + * Every 8 consecutive threads do reduction and calculate the amax of each row + * 16 elements are quantized and write to output_c at a time ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | +| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | +| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 1 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 7 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 4 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 3: Transpose, cast and store to output_t +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 2 times +* What each thread does in each loop: + * 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times + * Every 8 consecutive threads do reduction and calculate the amax of each column + * 16 elements are quantized and write to output_c at a time, for a total of 2 times ++------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+ +| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | | +| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | | +| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | | +| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | +| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | | +| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | | +| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | | +| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | | ++-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+ + +*/ +// clang-format on + +constexpr size_t kThreadsPerWarp = 32; + +// Hyperparameters for performance tuning +constexpr int kTileDim = 128; // Fixed to 128 beacause we are using 1x128 and 128x1 quantization +constexpr int kNVecIn = 8; // The number of elements each LDG touches +constexpr int kNVecOut = 16; // The number of elements each STG touches +constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches +constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total + +// Auto-calculated constants, do not modify directly) +static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); +static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem"); +constexpr int kSMemRow = kTileDim; +constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1; +constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem; +constexpr int kNumThreadsLoad = kTileDim / kNVecIn; +constexpr int kNumThreadsStore = kTileDim / kNVecOut; +static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); +static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); + +template < + bool kReturnTranspose, + bool kIsE8Scaling, + bool kPermuteScale, + bool kAligned, + typename CType, + typename IType, + typename OType> +__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( + const IType* const input, + OType* const output_c, + OType* const output_t, + CType* const tile_scales_inv_c, + CType* const tile_scales_inv_t, + const size_t row_length, + const size_t num_rows, + const size_t scale_stride_x, + const size_t scale_stride_y, + const size_t scale_t_stride_x, + const size_t scale_t_stride_y, + const float epsilon) { + using SMemVec = Vec; + using OVec = Vec; + union IVec { + Vec input_type; + Vec smem_type; + }; + + extern __shared__ char smem_base[]; + SMemVec* smem = reinterpret_cast(&smem_base[0]); + + // Step 1: Load input to shared memory + { + constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory + const size_t c_g = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = (size_t)blockIdx.y * kTileDim + r_s; // Row in global memory + const size_t stride_g = (size_t)r_stride * row_length; // Stride in global memory + const size_t num_ele = + c_g < row_length ? min((size_t)kNVecIn, row_length - c_g) : 0; // For not aligned case + const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + IVec input_vec; + // Step 1.1: Load from global memory (input) to registers + if constexpr (kAligned) { + input_vec.input_type.load_from(input_g); + } else { + if (r_g < num_rows) { + input_vec.input_type.load_from_elts(input_g, 0, num_ele); + } else { + input_vec.input_type.clear(); + } + } + // Step 1.2: Write to shared memory +#pragma unroll + for (int i = 0; i < kNVecIn / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i]; + } + // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case) + input_g += stride_g; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + __syncthreads(); + + // Step 2: Cast and store to output_c + { + constexpr int r_stride = + kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory + const size_t c_g = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = (size_t)blockIdx.y * kTileDim + r_s; // Row in global memory + const size_t stride_g = (size_t)r_stride * row_length; // Stride in global memory + const size_t num_ele = + c_g < row_length ? min((size_t)kNVecOut, row_length - c_g) : 0; // For not aligned case + OType* output_g = &output_c[r_g * row_length + c_g]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut / kNVecSMem]; + // Step 2.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } + // Step 2.2: Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j])); + } + } + // Step 2.3: Reduce amax +#pragma unroll + for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + // Step 2.4: Compute scale + CType scale = ComputeScale(amax, epsilon); + // Step 2.5: Write scale_inv + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g < num_rows); + } + if (write_scale_inv) { + CType scale_inv = 1.0 / scale; + size_t row_idx = (size_t)blockIdx.y * kTileDim + r_s; + size_t col_idx = (size_t)blockIdx.x; + if constexpr (kPermuteScale) { + size_t p_row = row_idx / kTileDim; + size_t p_col = col_idx; + size_t p_dep = row_idx % kTileDim; + size_t p_2d_stride = kTileDim * scale_stride_y; + tile_scales_inv_c[p_row * p_2d_stride + p_col * kTileDim + p_dep] = scale_inv; + } else { + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + } + } + // Step 2.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + output_vec.data.elt[i * kNVecSMem + j] = + static_cast(static_cast(smem_vec[i].data.elt[j]) * scale); + } + } + // Step 2.7: Store output_c + if constexpr (kAligned) { + output_vec.store_to(output_g); + } else { + if (r_g < num_rows) { + output_vec.store_to_elts(output_g, 0, num_ele); + } + } + // Step 2.8: Update output address, row index of shared memory (and row index of global memory for not aligned case) + output_g += stride_g; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + // Step 3: Transpose, cast and store to output_t + if constexpr (kReturnTranspose) { + constexpr int c_stride = + kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory + constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); + const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory + int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory + size_t r_g = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem; // Row in global memory + const size_t c_g = (size_t)blockIdx.y * kTileDim + r_s; // Column in global memory + const size_t stride_g = (size_t)c_stride * kNVecSMem * num_rows; // Stride in global memory + const size_t num_ele = + c_g < num_rows ? min((size_t)kNVecOut, num_rows - c_g) : 0; // For not aligned case + OType* output_g = &output_t[r_g * num_rows + c_g]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut]; + // Step 3.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + int r = r_s + i; + int c = c_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } +#pragma unroll + for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) { + // Step 3.2: Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx])); + } + // Step 3.3: Reduce amax +#pragma unroll + for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + // Step 3.4: Compute scale + CType scale = ComputeScale(amax, epsilon); + // Step 3.5: Write scale_inv_t + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g + smem_idx < row_length); + } + if (write_scale_inv) { + CType scale_inv = 1.0 / scale; + size_t row_idx = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem + smem_idx; + size_t col_idx = (size_t)blockIdx.y; + if constexpr (kPermuteScale) { + size_t p_row = row_idx / kTileDim; + size_t p_col = col_idx; + size_t p_dep = row_idx % kTileDim; + size_t p_2d_stride = kTileDim * scale_t_stride_y; + tile_scales_inv_t[p_row * p_2d_stride + p_col * kTileDim + p_dep] = scale_inv; + } else { + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + // Step 3.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + output_vec.data.elt[i] = + static_cast(static_cast(smem_vec[i].data.elt[smem_idx]) * scale); + } + // Step 3.7: Store output_t + if constexpr (kAligned) { + output_vec.store_to(output_g + smem_idx * num_rows); + } else { + if (r_g + smem_idx < row_length) { + output_vec.store_to_elts(output_g + smem_idx * num_rows, 0, num_ele); + } + } + } + // Step 3.8: Update output address, column index of shared memory (and row index of global memory for not aligned case) + output_g += stride_g; + c_s += c_stride; + if constexpr (!kAligned) { + r_g += c_stride * kNVecSMem; + } + } + } +} + +} // namespace +} // namespace transformer_engine + +namespace transformer_engine::detail { + +void nvte_quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, + SimpleTensor& scale_inv_t, SimpleTensor& output, + SimpleTensor& output_t, const float epsilon, + const bool return_transpose, const bool pow2_scale, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_transpose_vector_blockwise); + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + + const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; + size_t num_elements = row_length; + size_t num_rows = 1; + for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) { + num_rows *= input.shape.at(i); + num_elements *= input.shape.at(i); + } + + // Early return if the input tensor is empty + if (num_elements == 0) { + return; + } + + // Options for scale layout of cuBLAS GEMM kernel. + constexpr bool kPermuteScale = false; + bool permute_scale = false; + bool transpose_scales = true; + + NVTE_CHECK(input.shape.size() == output.shape.size(), + "Input and output must have the same shape."); + NVTE_CHECK((!transpose_scales || !permute_scale), + "Permute scale and transpose scales are mutually exclusive flags."); + + size_t scale_stride_x = 0; + size_t scale_stride_y = 0; + if (permute_scale) { + NVTE_CHECK(scale_inv.shape.size() == 3, "scale_inv must have 3 dimensions."); + size_t scale_k = scale_inv.shape[1]; + NVTE_CHECK(scale_inv.shape[2] == kTileDim, "Scale inner dimension must be kTileDim."); + scale_stride_x = 1; + scale_stride_y = scale_k; + } else { + NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2 when not permuting scale."); + size_t scale_k = scale_inv.shape[1]; + scale_stride_x = 1; + scale_stride_y = scale_k; + if (transpose_scales) { + std::swap(scale_stride_x, scale_stride_y); + } + } + + size_t scale_t_stride_x = 0; + size_t scale_t_stride_y = 0; + + if (return_transpose) { + NVTE_CHECK(output_t.shape.size() == input.shape.size(), + "output_t must have same number of dimensions as input."); + if (output_t.shape.size() > 0) { + NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t."); + for (size_t i = 1; i < output_t.shape.size(); ++i) { + NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t"); + } + } + + NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype."); + + if (permute_scale) { + NVTE_CHECK(scale_inv_t.shape.size() == 3, "Scale_t dimension must be 3."); + scale_t_stride_x = 1; + scale_t_stride_y = scale_inv_t.shape[1]; + NVTE_CHECK(scale_inv_t.shape[2] == kTileDim, "Scale_t inner dimension must be kTileDim."); + } else { + NVTE_CHECK(scale_inv_t.shape.size() == 2, + "Scale_t dimension must be 2 when not permuting scale."); + scale_t_stride_x = 1; + scale_t_stride_y = scale_inv_t.shape[1]; + if (transpose_scales) { + std::swap(scale_t_stride_x, scale_t_stride_y); + } + } + } + + const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); + const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype, + InputType, + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output.dtype, + OutputType, + + dim3 grid(num_blocks_x, num_blocks_y, 1); + + const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transpose, + kReturnTranspose, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + pow2_scale, + kPow2Scale, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + full_tile, + kAligned, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + // shared memory must be requested up + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + &block_scaled_1d_cast_transpose_kernel< + kReturnTranspose, + kPow2Scale, + kPermuteScale, + kAligned, + float, + InputType, + OutputType>, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + NVTE_CHECK( + err == cudaSuccess, "Failed to set dynamic shared memory size."); + } + block_scaled_1d_cast_transpose_kernel< + kReturnTranspose, + kPow2Scale, + kPermuteScale, + kAligned, + float, + InputType, + OutputType> + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), + row_length, + num_rows, + scale_stride_x, + scale_stride_y, + scale_t_stride_x, + scale_t_stride_y, + epsilon); + ) // kAligned + ) // kPow2Scale + ) // kReturnTranspose + ) // OutputType + ) // InputType + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace transformer_engine::detail \ No newline at end of file diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index d1ede8d98d..3ca5f5c124 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1245,6 +1245,35 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe workspace_tensor, stream); break; } + case NVTE_BLOCK_SCALING: { + // FIXME(kwyss): Currently ignoring IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters. + if (output_tensor->block_scaling_dim == 2) { + nvte_quantize_transpose_square_blockwise( + input_tensor->data, + output_tensor->scale_inv, + output_tensor->columnwise_scale_inv, + output_tensor->data, + output_tensor->columnwise_data, + /*epsilon=*/ output_tensor->amax_epsilon, + /*return_transpose=*/ output_tensor->has_columnwise_data(), + output_tensor->force_pow_2_scales, + stream); + } else if (output_tensor->block_scaling_dim == 1) { + nvte_quantize_transpose_vector_blockwise( + input_tensor->data, + output_tensor->scale_inv, + output_tensor->columnwise_scale_inv, + output_tensor->data, + output_tensor->columnwise_data, + /*epsilon=*/ output_tensor->amax_epsilon, + /*return_transpose=*/ output_tensor->has_columnwise_data(), + output_tensor->force_pow_2_scales, + stream); + } else { + NVTE_ERROR("Not supported block scaling dim."); + } + break; + } default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index e529289640..6e798ca748 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -349,6 +349,7 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); } } else { + // FIXME(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); } } diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index ff475caf21..a83a508aed 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -24,6 +24,22 @@ torch.bfloat16: tex.DType.kBFloat16, } +""" +This is a map: int -> torch.dtype +Used for resolving cuda extension types to torch. +Has one to one mapping with enum in +transformer_engine.h +""" +TE_DType_To_Torch = { + tex.DType.kByte: torch.uint8, + tex.DType.kFloat8E4M3: torch.float8_e4m3fn, + tex.DType.kFloat8E5M2: torch.float8_e5m2, + tex.DType.kInt32: torch.int32, + tex.DType.kFloat32: torch.float32, + tex.DType.kFloat16: torch.half, + tex.DType.kBFloat16: torch.bfloat16, +} + AttnMaskTypes = ( "no_mask", "padding", diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 40245cf2d9..5e51a71ca5 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -136,6 +136,36 @@ class Float8Quantizer : public Quantizer { std::optional rowwise_data = std::nullopt) const override; }; +class Float8BlockQuantizer : public Quantizer { + public: + // Which float8 type is used for q data. + DType dtype; + + private: + // Options about how to quantize the tensor + // Quantization scales are rounded down to powers of 2. + bool force_pow_2_scales = false; + // Amax within quantization tile has a floor of epsilon. + float amax_epsilon = 0.0; + int block_scaling_dim = 2; + + public: + // Initializes from a python handle to a Float8BlockQuantizer + explicit Float8BlockQuantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_BLOCK_SCALING; } + + // Gets rowwise and columnwise_data from tensor and sets them on wrapper + void set_quantization_params(TensorWrapper* tensor) const override; + + // Create a python Float8BlockQuantized tensor and C++ wrapper + // for the tensor. Should set quantized data, scales for rowwise + // and optionally columnwise usage. + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + class MXFP8Quantizer : public Quantizer { public: DType dtype; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 66dafdaafb..e1e8da1bb6 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -13,7 +13,9 @@ namespace transformer_engine::pytorch { -py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::object& output, +py::object quantize(const at::Tensor& tensor, + py::handle quantizer, + const py::object& output, std::optional noop) { init_extension(); auto my_quantizer = convert_quantizer(quantizer); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 442837d767..27fcb524fa 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -27,6 +27,9 @@ PyTypeObject *Float8QuantizerClass = nullptr; PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove PyTypeObject *MXFP8TensorBasePythonClass = nullptr; PyTypeObject *MXFP8QuantizerClass = nullptr; +PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; +PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr; +PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; void init_float8_extension() { if (Float8TensorPythonClass) return; @@ -58,9 +61,31 @@ void init_mxfp8_extension() { "Internal error: could not initialize pyTorch MXFP8 extension."); } +void init_float8blockwise_extension() { + if (Float8BlockwiseQTensorBasePythonClass) return; + auto fp8_module = + py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor"); + auto fp8_base_module = py::module_::import( + "transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base"); + Float8BlockwiseQuantizerClass = reinterpret_cast( + PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer")); + Float8BlockwiseQTensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase")); + Float8BlockwiseQTensorPythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor")); + + NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr, + "Internal error: could not initialize pyTorch float8blockwise extension."); + NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr, + "Internal error: could not initialize pyTorch float8blockwise extension."); + NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch float8blockwise extension."); +} + void init_extension() { init_float8_extension(); init_mxfp8_extension(); + init_float8blockwise_extension(); } } // namespace transformer_engine::pytorch @@ -73,6 +98,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("output") = py::none(), py::arg("noop") = py::none()); m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); + m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index effeb8cb4d..b1c77d84ec 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -140,6 +140,133 @@ std::pair Float8Quantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } +Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { + this->dtype = quantizer.attr("dtype").cast(); + this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); + this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); + this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); +} + +void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const { + // Change the rowwise and columnwise_data to the configured dtype. + // May be a switch between E5M2 and E4M3. + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); + + // Set options on TensorWrapper from quantization. + tensor->set_qopt_force_pow_2_scales(force_pow_2_scales); + tensor->set_qopt_amax_epsilon(amax_epsilon); + tensor->set_qopt_block_scaling_dim(block_scaling_dim); +} + +std::pair Float8BlockQuantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector torch_shape; + size_t numel = 1; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + numel *= s; + } + + TensorWrapper tensor(NVTE_BLOCK_SCALING); + at::TensorOptions opts; + at::TensorOptions scale_opts; + at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise; + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); + + size_t k_dim = torch_shape.size() == 0 ? 1u : torch_shape.back(); + size_t m_dim = numel / k_dim; + constexpr size_t kBlockLen = 128; + + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data_rowwise = std::move(*rowwise_data); + } else { + data_rowwise = at::empty(torch_shape, opts); + } + size_t sinv0 = 0; + size_t sinv1 = 0; + if (block_scaling_dim == 2) { + sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; + sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4); + } else if (block_scaling_dim == 1) { + sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; + sinv1 = m_dim; + } else { + NVTE_CHECK(false, "Unsupported block_scaling_dim in create_tensor rowwise."); + } + scale_inv_rowwise = at::empty({sinv0, sinv1}, scale_opts); + tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(scale_inv_rowwise.data_ptr(), DType::kFloat32, + std::vector{sinv0, sinv1}); + } + + if (columnwise_usage) { + std::vector torch_columnwise_shape; + std::vector columnwise_shape; + NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape."); + if (torch_shape.size() > 0) { + torch_columnwise_shape.reserve(torch_shape.size()); + columnwise_shape.reserve(shape.size()); + torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); + columnwise_shape.push_back(shape[shape.size() - 1]); + for (size_t i = 0; i < torch_shape.size() - 1; ++i) { + torch_columnwise_shape.push_back(torch_shape[i]); + columnwise_shape.push_back(shape[i]); + } + } + size_t sinv0 = 0; + size_t sinv1 = 0; + if (block_scaling_dim == 2) { + sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; + sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4); + } else if (block_scaling_dim == 1) { + sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; + sinv1 = k_dim; + } else { + NVTE_CHECK(false, "Unsupported block_scaling_dim in create_tensor columnwise."); + } + data_colwise = at::empty(torch_columnwise_shape, opts); + scale_inv_colwise = at::empty({sinv0, sinv1}, scale_opts); + + tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, + columnwise_shape); + tensor.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32, + std::vector{sinv0, sinv1}); + } + this->set_quantization_params(&tensor); + + py::object ret; + if (internal) { + py::handle Float8BlockwiseQTensorClass( + reinterpret_cast(Float8BlockwiseQTensorBasePythonClass)); + ret = Float8BlockwiseQTensorClass( + "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, + "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + } else { + py::handle Float8BlockwiseQTensorClass( + reinterpret_cast(Float8BlockwiseQTensorPythonClass)); + ret = Float8BlockwiseQTensorClass( + "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, + "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, + "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, + "quantizer"_a = this->quantizer); + } + + return {std::move(tensor), std::move(ret)}; +} + MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); } diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index d2607e4ed0..48b0919a69 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -74,6 +74,38 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) return ret; } +TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantizer) { + const DType dtype = tensor.attr("_fp8_dtype").cast(); + auto ret = TensorWrapper(NVTE_BLOCK_SCALING); + + bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + + if (rowwise_usage) { + const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); + const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); + void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); + const auto &shape = getTensorShape(data_rowwise); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape); + + const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); + ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); + } + + if (columnwise_usage) { + const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast(); + const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); + void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); + const auto &shape = getTensorShape(data_colwise); + ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); + + const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); + ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); + } + quantizer->set_quantization_params(&ret); + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 0679528b94..ba3d4bb6fc 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -24,6 +24,9 @@ extern PyTypeObject *Float8QuantizerClass; extern PyTypeObject *MXFP8TensorPythonClass; extern PyTypeObject *MXFP8TensorBasePythonClass; extern PyTypeObject *MXFP8QuantizerClass; +extern PyTypeObject *Float8BlockwiseQTensorPythonClass; +extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass; +extern PyTypeObject *Float8BlockwiseQuantizerClass; void init_extension(); @@ -45,6 +48,15 @@ inline bool IsMXFP8Tensor(PyObject *obj) { return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; } +inline bool IsFloat8BlockwiseQParams(PyObject *obj) { + return Py_TYPE(obj) == Float8BlockwiseQuantizerClass; +} + +inline bool IsFloat8BlockwiseQTensor(PyObject *obj) { + return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass || + Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass; +} + TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); template @@ -56,6 +68,9 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizati std::unique_ptr CreateMXFP8Params(const py::handle params); +TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, + Quantizer *quantization_params); + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } @@ -64,7 +79,9 @@ constexpr std::array custom_types_converters = { std::make_tuple(IsFloat8Tensor, IsFloat8QParams, NVTETensorFromFloat8Tensor, CreateQuantizer), std::make_tuple(IsMXFP8Tensor, IsMXFP8QParams, NVTETensorFromMXFP8Tensor, - CreateQuantizer)}; + CreateQuantizer), + std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQParams, + NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer)}; } // namespace detail diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py new file mode 100644 index 0000000000..b1a61fae4b --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -0,0 +1,250 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mixin class holding data specific for Float8BlockwiseQTensor""" + +from __future__ import annotations +import math +from typing import Optional, Dict, Any, Tuple +import torch + +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ...constants import TE_DType as torch_to_transformer_engine_dtype +from ...constants import TE_DType_To_Torch + +from ..quantized_tensor import Quantizer + + +class Float8BlockwiseQTensorBase: + """Mixin class that holds data attributes of Float8BlockwiseQTensor. + + Float8BlockwiseQTensor inherits from the PyTorch tensor class and this + mixin class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. + """ + + _rowwise_data: Optional[torch.Tensor] + _columnwise_data: Optional[torch.Tensor] + _quantizer: Quantizer + _fp8_dtype: TE_DType + _rowwise_scale_inv: Optional[torch.Tensor] + _columnwise_scale_inv: Optional[torch.Tensor] + + def __new__( + cls, + *args, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: Optional[torch.Tensor], + fp8_dtype: TE_DType, + quantizer: Quantizer, + **kwargs, + ): + instance = super().__new__(cls, *args, **kwargs) + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._quantizer = quantizer + instance._fp8_dtype = fp8_dtype + instance._rowwise_scale_inv = rowwise_scale_inv + instance._columnwise_scale_inv = columnwise_scale_inv + + return instance + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale_inv": self._rowwise_scale_inv, + "columnwise_data": self._columnwise_data, + "columnwise_scale_inv": self._columnwise_scale_inv, + "fp8_dtype": self._fp8_dtype, + "quantizer": self._quantizer, + } + + def prepare_for_saving( + self, + ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: + """Prepare the tensor base for saving for backward + + FIXME(kwyss): Should this clear out data? + FIXME(kwyss): What about dq scales? + """ + tensors = [self._rowwise_data, self._columnwise_data] + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + return tensors[2:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._rowwise_data, self._columnwise_data + + def _transpose_dq_columnwise_output( + self, columnwise_dq: torch.Tensor + ) -> torch.Tensor: + """Takes dequantized columnwise data and permutes to a rowwise shape""" + if columnwise_dq.dim() < 2: + return columnwise_dq + permute_dims = [x for x in range(1, columnwise_dq.dim())] + permute_dims.append(0) + return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous() + + def _dequantize_vectorwise( + self, *, dtype: torch.dtype = torch.float32 + ) -> torch.Tensor: + block_len = 128 + + q_M, q_K = 1, 1 + if self._rowwise_data is not None: + q = self._rowwise_data + scale_inv = self._rowwise_scale_inv + transpose_output = False + if len(q.shape) >= 1: + q_K = q.shape[-1] + for i in range(len(q.shape) - 1): + q_M *= q.shape[i] + else: + assert self._columnwise_data is not None, "No data to dequantize" + q = self._columnwise_data + scale_inv = self._columnwise_scale_inv + transpose_output = True + if len(q.shape) >= 1: + q_M = q.shape[0] + for i in range(1, len(q.shape)): + q_K *= q.shape[i] + + orig_shape = q.shape + q = q.reshape(q_M, q_K) + k_tiles, m = scale_inv.shape + if q_K % block_len != 0: + k_pad_amount = (block_len - (q_K % block_len)) % block_len + q = torch.nn.functional.pad( + q, (0, k_pad_amount, 0, 0), mode="constant", value=0 + ).contiguous() + _, padded_K = q.shape + q_tiled = q.reshape(q_M, k_tiles, block_len) + dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(m, k_tiles, 1) + torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype] + result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale + if padded_K != q_K: + result = result.reshape(q_M, padded_K)[:, :q_K] + result = result.to(dtype) + if len(orig_shape) == 0: + result = result.reshape([]) + else: + result = result.reshape(*orig_shape).contiguous() + + if transpose_output: + return self._transpose_dq_columnwise_output(result) + return result + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8BlockwiseQTensor + """ + block_len = 128 + assert self._quantizer is not None + if self._quantizer.block_scaling_dim != 2: + assert self._quantizer.block_scaling_dim == 1 + return self._dequantize_vectorwise(dtype=dtype) + + def format_scale_as_logical_shape(q_M, q_K, scales, block_len): + # The GEMM for 2D blocks required padding in the scales. + derived_scale_k_shape = math.ceil(q_K / block_len) + scale_M, scale_K = scales.shape + if derived_scale_k_shape == scale_K: + return scales + else: + return scales[:, :derived_scale_k_shape].contiguous() + return formatted_scales + + q_M, q_K = 1, 1 + if self._rowwise_data is not None: + q = self._rowwise_data + scale_inv = self._rowwise_scale_inv + transpose_output = False + if len(q.shape) >= 1: + q_K = q.shape[-1] + for i in range(len(q.shape) - 1): + q_M *= q.shape[i] + else: + assert self._columnwise_data is not None, "No data to dequantize" + q = self._columnwise_data + scale_inv = self._columnwise_scale_inv + transpose_output = True + if len(q.shape) >= 1: + q_M = q.shape[0] + for i in range(1, len(q.shape)): + q_K *= q.shape[i] + + orig_shape = q.shape + q = q.reshape(q_M, q_K) + formatted_scales = format_scale_as_logical_shape(q_M, q_K, scale_inv, block_len) + assert len(formatted_scales.shape) == 2 + m_tiles, k_tiles = formatted_scales.shape + unpadded_m, unpadded_k = q_M, q_K + m_block_len = block_len + k_block_len = block_len + if q_M % m_block_len != 0 or q_K % k_block_len != 0: + m_pad_amount = (m_block_len - (q_M % m_block_len)) % m_block_len + k_pad_amount = (k_block_len - (q_K % k_block_len)) % k_block_len + q = torch.nn.functional.pad( + q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0 + ).contiguous() + padded_M, padded_K = q.shape + q_tiled = q.reshape(m_tiles, m_block_len, k_tiles, k_block_len) + + torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype] + + result = q_tiled.view(torch_q_dtype).to(torch.float32) * formatted_scales.view( + m_tiles, 1, k_tiles, 1 + ) + result = result.view(padded_M, padded_K).to(dtype) + if padded_M != unpadded_m or padded_K != unpadded_k: + result = result[:unpadded_m, :unpadded_k] + if len(orig_shape) == 0: + result = result.reshape([]) + else: + result = result.reshape(*orig_shape).contiguous() + if transpose_output: + return self._transpose_dq_columnwise_output(result) + return result + + def size(self, *args, **kwargs): + # pylint: disable=missing-function-docstring + if self._rowwise_data is not None: + return self._rowwise_data.size(*args, **kwargs) + else: + dims = list(self._columnwise_data.size(*args, **kwargs)) + reordered = [] + for i in range(1, len(dims)): + reordered.append(dims[i]) + reordered.append(dims[0]) + return torch.Size(reordered) + + def __repr__(self): + if self._rowwise_data is not None: + data = self.dequantize() + descriptor = "rowwise" + scale_inv = self._rowwise_scale_inv + else: + data = self.dequantize() + descriptor = "columnwise" + scale_inv = self._columnwise_scale_inv + return ( + "Float8BlockwiseQTensorBase(" + f"fp8_dtype={self._fp8_dtype}, " + f"{descriptor}_scaled_data={data_rowwise}" + f"{descriptor}_scale_inv={scale_inv}, " + ")" + ) diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index d78bd55d9a..45de7e621e 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -96,7 +96,7 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorB """Prepare the tensor base for saving for backward After calling this, the tensor instance does not hold any - data. + data. Yes it does? TODO """ tensors = [self._rowwise_data, self._columnwise_data] diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py new file mode 100644 index 0000000000..9a5568af48 --- /dev/null +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -0,0 +1,553 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor class with FP8 data quantized with NxN tiles""" +from __future__ import annotations +from typing import Optional, Tuple, Iterable +import warnings + +import math +import torch +import transformer_engine_torch as tex + +from transformer_engine_torch import DType as TE_DType +from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from ..utils import devices_match, round_up_to_nearest_multiple + +aten = torch.ops.aten + + +class Float8BlockQuantizer(Quantizer): + """Builder class for tensors quantized with current scaling using + NxN quantization tilings to choose scale. + + This class is typically used to convert a high-precision tensor + (e.g. in FP32 or BF16) into a quantized tensor (e.g. in FP8). + + """ + + dtype: TE_DType + block_len: int + amax_epsilon: float + force_pow_2_scales: bool + block_scaling_dim: int + + def __init__( + self, + fp8_dtype: TE_DType, + *, + rowwise: bool, + columnwise: bool, + amax_epsilon: float = 0.0, + force_pow_2_scales: bool = False, + block_scaling_dim: int = 2, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + assert rowwise + self.dtype = fp8_dtype + self.block_len = 128 + self.force_pow_2_scales = force_pow_2_scales + self.amax_epsilon = amax_epsilon + self.block_scaling_dim = block_scaling_dim + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + assert isinstance( + dst, Float8BlockwiseQTensor + ), f"Cannot store quantized blockwise tensor in {type(dst)} type." + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + + dst._fp8_dtype = self.dtype + return dst + + def get_scale_shape( + self, shape: Iterable[int], columnwise: bool + ) -> Tuple[int, int]: + # cuBLAS kernel format (for NxN by NxN and 1xN by NxN GEMMs) + # The scales for 2D block quantized tensors must have scales padded + # to multiples of 4 on the inner dimension. TODO: Verify whether outer + # dimension also to be padded for either GEMM. + if self.block_scaling_dim == 2: + logical_scale_shape = [1, 1] + for i in range(len(shape) - 1): + logical_scale_shape[-2] *= shape[i] + if len(shape) > 0: + logical_scale_shape[-1] = math.ceil(shape[-1] / self.block_len) + logical_scale_shape[-2] = math.ceil( + logical_scale_shape[-2] / self.block_len + ) + if columnwise: + tmp = logical_scale_shape[-1] + logical_scale_shape[-1] = logical_scale_shape[-2] + logical_scale_shape[-2] = tmp + logical_scale_shape[-1] = round_up_to_nearest_multiple( + logical_scale_shape[-1], 4 + ) + return tuple(logical_scale_shape) + else: + assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported" + + logical_scale_shape = [1, 1] + for i in range(len(shape) - 1): + logical_scale_shape[-1] *= shape[i] + if len(shape) > 0: + logical_scale_shape[-2] = shape[-1] + if not columnwise: + logical_scale_shape[-2] = math.ceil( + logical_scale_shape[-2] / self.block_len + ) + return tuple(logical_scale_shape) + else: + logical_scale_shape[-1] = math.ceil( + logical_scale_shape[-1] / self.block_len + ) + return (logical_scale_shape[1], logical_scale_shape[0]) + + def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: + if len(shape) == 0: + return tuple() + colwise_shape = [shape[-1]] + for i in range(len(shape) - 1): + colwise_shape.append(shape[i]) + return tuple(colwise_shape) + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> Float8BlockwiseQTensor: + """Construct quantized tensor with uninitialized data""" + if device is None: + device = torch.device("cuda") + + # Allocate FP8 data + data = torch.empty(shape, dtype=torch.uint8, device=device) + scale_shape = self.get_scale_shape(shape, columnwise=False) + scale_inv = torch.empty( + scale_shape, + dtype=torch.float32, + device=device, + ) + + # Allocate FP8 data transpose if needed + columnwise_data = None + columnwise_scale_inv = None + if self.columnwise_usage: + columnwise_data = torch.empty( + self.get_columnwise_shape(shape), dtype=torch.uint8, device=device + ) + columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) + columnwise_scale_inv = torch.empty( + columnwise_scale_shape, + dtype=torch.float32, + device=device, + ) + + # Construct FP8 tensor + return Float8BlockwiseQTensor( + shape=shape, + dtype=dtype, + fp8_dtype=self.dtype, + rowwise_data=data, + rowwise_scale_inv=scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + quantizer=self, + requires_grad=requires_grad, + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + # NOTE: This interface is specific to requirements like delayed scaling + # where state from an estimator influences distribution parameters. + pass + + +class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): + """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP8. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + rowwise_data: torch.Tensor + FP8 data in a uint8 tensor matching shape of dequantized tensor. + rowwise_scale_inv: torch.Tensor + FP32 dequantization scales in GEMM format for dequantizing rowwise_data. + columnwise_data: Optional[torch.Tensor] + FP8 data in a uint8 tensor matching shape of dequantized tensor transpose. + columnwise_scale_inv: Optional[torch.Tensor] + FP32 dequantization scales in GEMM format for dequantizing columnwise_data. + + fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3 + FP8 format. + quantizer: Quantizer - the Float8BlockQuantizer that quantized this tensor and + holds configuration about quantization and dequantization modes. + """ + + def __repr__(self, *, tensor_contents=None): + return f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})" + + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. + + """ + assert self._quantizer is not None + return self._quantizer + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> Float8BlockwiseQTensor: + """Update FP8 data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize()) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8BlockwiseQTensor + + By default the resulting tensor's dtype is the + Float8BlockwiseQTensor's pre-quantized dtype. + """ + if dtype is not None: + dequant_dtype = dtype + else: + dequant_dtype = self.dtype + return super().dequantize(dtype=dequant_dtype) + + def detach(self) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + return Float8BlockwiseQTensor.make_like(self) + + def update_usage(self, rowwise_usage=True, columnwise_usage=True): + """ + update_usage can be used to clear out one of two possible copies of the data. + """ + + assert ( + columnwise_usage or rowwise_usage + ), "Must retain some data either columnwise or rowwise" + + if columnwise_usage and rowwise_usage: + assert ( + self._rowwise_data is not None + and self._rowwise_scale_inv is not None + and self._columnwise_data is not None + and self._columnwise_scale_inv is not None + ), "Cannot update to rowwise and columnwise usage." + return + + if rowwise_usage: + assert ( + self._rowwise_data is not None and self._rowwise_scale_inv is not None + ), "Cannot update to rowwise usage." + self._columnwise_data = None + self._columnwise_scale_inv = None + return + if columnwise_usage: + assert ( + self._columnwise_data is not None + and self._columnwise_scale_inv is not None + ), "Cannot update to columnwise usage." + self._rowwise_data = None + self._rowwise_scale_inv = None + return + + return + + def clone(self) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + rowwise_data = None + if self._rowwise_data is not None: + rowwise_data = self._rowwise_data.detach().clone() + columnwise_data = None + if self._columnwise_data is not None: + columnwise_data = self._columnwise_data.detach().clone() + return _IdentityFunc.apply( + self, + { + "rowwise_data": rowwise_data, + "columnwise_data": columnwise_data, + }, + ) + + def view(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> Float8BlockwiseQTensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. + + """ + if ( + self._rowwise_data is not None + and self._rowwise_data.is_contiguous(memory_format=memory_format) + and ( + (self._columnwise_data is None) + or (self._columnwise_data.is_contiguous(memory_format=memory_format)) + ) + ): + return self + raise ValueError( + "Float8BlockwiseQTensor does not support different memory formats!" + ) + + def clear(self): + """Deallocate this tensor's memory. Typically not needed and must be used carefully.""" + self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None + self._columnwise_data = ( + torch.Tensor() if self._columnwise_data is not None else None + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # View op + if func == aten.view.default: + tensor = args[0] + data = tensor._rowwise_data + out_data = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + out_shape = out_data.size() + return Float8BlockwiseQTensor( + shape=out_shape, + dtype=tensor.dtype, + rowwise_data=out_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=tensor._columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + quantizer=tensor._quantizer, + requires_grad=False, + fp8_dtype=tensor._fp8_dtype, + ) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + dtype: torch.dtype, + ) -> Float8BlockwiseQTensor: + """Build Float8BlockwiseQTensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return Float8BlockwiseQTensor( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + fp8_dtype=fp8_dtype, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + dtype=dtype, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling to remove references to FP8 metadata objects""" + return ( + Float8BlockwiseQTensor._make_in_reduce_ex, + ( + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._fp8_dtype, + self.dtype, + ), + ) + + def _get_data(self) -> Float8BlockwiseQTensor: + """Get tensor data property""" + return super().data + + @torch.no_grad() + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Just takes FP8 data if setting from a Float8BlockwiseQTensor. Otherwise + casts to FP8. + + """ + + # Tensor device + new_device = tensor.device if tensor.is_cuda else self.device + + # Just copy FP8 data if other tensor is Float8BlockwiseQTensor + if isinstance(tensor, Float8BlockwiseQTensor): + if ( # pylint: disable=too-many-boolean-expressions + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.storage_offset() != tensor.storage_offset() + or self.dtype != tensor.dtype + or self.layout != tensor.layout + or not devices_match(self.device, new_device) + ): + dummy_tensor = torch.Tensor._make_wrapper_subclass( + Float8BlockwiseQTensor, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + device=new_device, + ) + # pylint: disable=unnecessary-dunder-call + super(Float8BlockwiseQTensor, type(self)).data.__set__( + self, dummy_tensor + ) + self._rowwise_data = tensor._rowwise_data + self._columnwise_data = tensor._columnwise_data + self._quantizer = tensor._quantizer + self._fp8_dtype = tensor._fp8_dtype + self._rowwise_scale_inv = tensor._rowwise_scale_inv + self._columnwise_scale_inv = tensor._columnwise_scale_inv + return + + # Quantize to FP8 + assert self._quantizer is not None, "Can't quantize without a quantizer" + self.data = self._quantizer.quantize(tensor) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) + + # Cast to FP8 when setting Float8BlockwiseQTensor.data + data = property(_get_data, _set_data) + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the Float8BlockwiseQTensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8BlockwiseQTensor, + shape: Optional[list[int]] = None, + ) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + if shape != ctx.shape: + raise NotImplementedError("View not implemented.") + else: + return tensor + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, Float8BlockwiseQTensor): + raise NotImplementedError("View bwd not implemented") + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the Float8BlockwiseQTensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8BlockwiseQTensor, + shape: Optional[list[int]] = None, + ) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(ctx.shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape != ctx.shape: + raise NotImplementedError("Reshape not implemented yet.") + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, Float8BlockwiseQTensor): + raise NotImplementedError("Reshape bwd not implemented yet.") + return grad.view(ctx.shape), None From e6c8c770524b8d06a9a257aa90934a5aff70efbf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Feb 2025 00:28:27 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../blockwise_quantizer_reference.py | 24 +--- .../test_float8_blockwise_scaling_exact.py | 93 ++++++------- tests/pytorch/test_float8blockwisetensor.py | 14 +- transformer_engine/common/common.h | 8 +- .../common/transpose/cast_transpose.h | 26 ++-- .../common/transpose/compute_scale.cuh | 2 +- .../quantize_transpose_square_blockwise.cu | 125 ++++++------------ .../quantize_transpose_vector_blockwise.cu | 101 +++++--------- .../common/util/cast_kernels.cuh | 34 ++--- .../pytorch/csrc/extensions/cast.cpp | 4 +- .../pytorch/csrc/extensions/quantizer.cpp | 3 +- .../_internal/float8_blockwise_tensor_base.py | 8 +- .../pytorch/tensor/float8_blockwise_tensor.py | 40 ++---- 13 files changed, 177 insertions(+), 305 deletions(-) diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py index 5331fd4839..72cb062c31 100644 --- a/tests/pytorch/references/blockwise_quantizer_reference.py +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -94,9 +94,7 @@ def _munge_scale_tensor(s: torch.Tensor) -> torch.Tensor: if K % 4 == 0: return s k_pad = 4 - (K % 4) - return torch.nn.functional.pad( - s, (0, k_pad), mode="constant", value=0 - ).contiguous() + return torch.nn.functional.pad(s, (0, k_pad), mode="constant", value=0).contiguous() s = _munge_scale_tensor(unmunged.scale) if unmunged.scale_t is None: @@ -199,9 +197,7 @@ def _quantize_square_block_tiling( qx_t = None scale_inv_t = None - return QuantizeResult( - data=qx, scale=scale_inv, data_t=qx_t, scale_t=scale_inv_t - ) + return QuantizeResult(data=qx, scale=scale_inv, data_t=qx_t, scale_t=scale_inv_t) @classmethod def _quantize_vectorwise_reference( @@ -284,9 +280,7 @@ def _quantize_vector_tiling( else: qout_t, scale_inv_t = None, None - return QuantizeResult( - data=qout, scale=scale_inv, data_t=qout_t, scale_t=scale_inv_t - ) + return QuantizeResult(data=qout, scale=scale_inv, data_t=qout_t, scale_t=scale_inv_t) def ref_dequantize_rowwise( self, @@ -297,20 +291,14 @@ def ref_dequantize_rowwise( ) -> torch.Tensor: assert q.dim() == 2 q_M, q_K = q.shape - s = self.scale_munger.demunge_scale_shape_from_backend( - (q_M, q_K), s, quant_tile_shape - ) + s = self.scale_munger.demunge_scale_shape_from_backend((q_M, q_K), s, quant_tile_shape) assert len(s.shape) == 2 m_tiles, k_tiles = s.shape M, K = q.shape unpadded_m, unpadded_k = M, K if M % quant_tile_shape[0] != 0 or K % quant_tile_shape[1] != 0: - m_pad_amount = ( - quant_tile_shape[0] - (M % quant_tile_shape[0]) - ) % quant_tile_shape[0] - k_pad_amount = ( - quant_tile_shape[1] - (K % quant_tile_shape[1]) - ) % quant_tile_shape[1] + m_pad_amount = (quant_tile_shape[0] - (M % quant_tile_shape[0])) % quant_tile_shape[0] + k_pad_amount = (quant_tile_shape[1] - (K % quant_tile_shape[1])) % quant_tile_shape[1] q = torch.nn.functional.pad( q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0 ).contiguous() diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index 3603268d09..16647184d6 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -14,14 +14,14 @@ Float8BlockQuantizer, Float8BlockwiseQTensor, ) -from tests.pytorch.references.blockwise_quantizer_reference import BlockwiseQuantizerReference, QuantizeResult +from tests.pytorch.references.blockwise_quantizer_reference import ( + BlockwiseQuantizerReference, + QuantizeResult, +) + def initialize_for_many_scales( - x_shape_2d: Tuple[int, int], - tile_shape: Tuple[int, int], - *, - dtype: torch.dtype, - device: str + x_shape_2d: Tuple[int, int], tile_shape: Tuple[int, int], *, dtype: torch.dtype, device: str ) -> torch.Tensor: """ Put separate distributions into each quantization tile @@ -56,9 +56,9 @@ def initialize_for_many_scales( max_dst_vals[0] - min_dst_vals[0], max_dst_vals[1] - min_dst_vals[1], ) - result[ - min_dst_vals[0] : max_dst_vals[0], min_dst_vals[1] : max_dst_vals[1] - ] = tile[: max_src_vals[0], : max_src_vals[1]] + result[min_dst_vals[0] : max_dst_vals[0], min_dst_vals[1] : max_dst_vals[1]] = tile[ + : max_src_vals[0], : max_src_vals[1] + ] return result @@ -82,9 +82,7 @@ def initialize_for_many_scales( ], ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize( - "quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str -) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("eps", [0, 1e-12], ids=["eps_0", "eps_1e-12"]) @pytest.mark.parametrize( "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] @@ -103,21 +101,23 @@ def test_quantization_block_tiling_versus_reference( ) -> None: te_dtype = TE_DType[quant_dtype] if tile_size == (1, 128): - block_scaling_dim=1 + block_scaling_dim = 1 elif tile_size == (128, 128): - block_scaling_dim=2 + block_scaling_dim = 2 else: raise ValueError("Non support tile size") # This test runs a comparison of the ref class versus the class using # CUDA kernels to quantize. They should quantize identically for pixels # that are not DC values in the scale factor shape. ref_quantizer = BlockwiseQuantizerReference() - sut_quantizer = Float8BlockQuantizer(fp8_dtype=te_dtype, - rowwise=True, - columnwise=return_transpose, - amax_epsilon=eps, - force_pow_2_scales=pow_2_scales, - block_scaling_dim=block_scaling_dim) + sut_quantizer = Float8BlockQuantizer( + fp8_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + amax_epsilon=eps, + force_pow_2_scales=pow_2_scales, + block_scaling_dim=block_scaling_dim, + ) # Setup device and random seed device = "cuda" @@ -129,10 +129,7 @@ def test_quantization_block_tiling_versus_reference( x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device) x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False) - x_fp8_sut = sut_quantizer.update_quantized( - x, - x_fp8_sut - ) + x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut) assert x_fp8_sut._rowwise_data is not None qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype) @@ -142,8 +139,12 @@ def test_quantization_block_tiling_versus_reference( sx_t = x_fp8_sut._columnwise_scale_inv qresult_ref = ref_quantizer.quantize( - x, quant_dtype=quant_dtype, return_transpose=return_transpose, - eps=eps, pow_2_scales=pow_2_scales, quant_tile_shape=tile_size + x, + quant_dtype=quant_dtype, + return_transpose=return_transpose, + eps=eps, + pow_2_scales=pow_2_scales, + quant_tile_shape=tile_size, ) qx_ref, sx_ref, qx_t_ref, sx_t_ref = ( qresult_ref.data, @@ -154,7 +155,7 @@ def test_quantization_block_tiling_versus_reference( # Check torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0) - if (tile_size[0] != 1): + if tile_size[0] != 1: # Zero out values that are don't care values # cuBLAS has padding of 2D tensors. scale_mask = torch.ones( @@ -173,7 +174,7 @@ def test_quantization_block_tiling_versus_reference( assert qx_t_ref is not None assert sx_t is not None assert sx_t_ref is not None - if (tile_size[0] != 1): + if tile_size[0] != 1: scale_mask = torch.ones( (math.ceil(N / tile_size[0]), math.ceil(M / tile_size[1])), device=sx_t.device, @@ -189,6 +190,7 @@ def test_quantization_block_tiling_versus_reference( assert qx_t is None and qx_t_ref is None assert sx_t is None and sx_t_ref is None + @pytest.mark.parametrize( "M, N", [ @@ -197,9 +199,7 @@ def test_quantization_block_tiling_versus_reference( ], ) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -@pytest.mark.parametrize( - "quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str -) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) @pytest.mark.parametrize("eps", [0, math.pow(2, -125)], ids=["eps_0", "eps_small"]) @pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "f32scales"]) @pytest.mark.parametrize("tile_size", [(1, 128)]) @@ -218,18 +218,20 @@ def test_quantization_block_tiling_extrema_versus_reference( # branch coverage of scale computation. te_dtype = TE_DType[quant_dtype] if tile_size == (1, 128): - block_scaling_dim=1 + block_scaling_dim = 1 elif tile_size == (128, 128): - block_scaling_dim=2 + block_scaling_dim = 2 else: raise ValueError("Non support tile size") ref_quantizer = BlockwiseQuantizerReference() - sut_quantizer = Float8BlockQuantizer(fp8_dtype=te_dtype, - rowwise=True, - columnwise=False, - amax_epsilon=eps, - force_pow_2_scales=pow_2_scales, - block_scaling_dim=block_scaling_dim) + sut_quantizer = Float8BlockQuantizer( + fp8_dtype=te_dtype, + rowwise=True, + columnwise=False, + amax_epsilon=eps, + force_pow_2_scales=pow_2_scales, + block_scaling_dim=block_scaling_dim, + ) # Setup device and random seed device = "cuda" seed = 0 @@ -245,16 +247,17 @@ def test_quantization_block_tiling_extrema_versus_reference( # Run cast and transpose kernel # Internal call ops.quantize_tensorwise x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False) - x_fp8_sut = sut_quantizer.update_quantized( - x, - x_fp8_sut - ) + x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut) qx = x_fp8_sut._rowwise_data.view(dtype=quant_dtype) sx = x_fp8_sut._rowwise_scale_inv qresult_ref = ref_quantizer.quantize( - x, quant_dtype=quant_dtype, return_transpose=return_transpose, - eps=eps, pow_2_scales=pow_2_scales, quant_tile_shape=tile_size + x, + quant_dtype=quant_dtype, + return_transpose=return_transpose, + eps=eps, + pow_2_scales=pow_2_scales, + quant_tile_shape=tile_size, ) qx_ref, sx_ref = ( qresult_ref.data, diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 6deb714ce2..7058fdb22f 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -78,12 +78,8 @@ def test_constructor( shape=dims, dtype=dtype, rowwise_data=torch.zeros(dims, device="cuda", dtype=torch.uint8), - rowwise_scale_inv=torch.zeros( - scale_dims, device="cuda", dtype=torch.float32 - ), - columnwise_data=torch.zeros( - columnwise_dims, device="cuda", dtype=torch.uint8 - ), + rowwise_scale_inv=torch.zeros(scale_dims, device="cuda", dtype=torch.float32), + columnwise_data=torch.zeros(columnwise_dims, device="cuda", dtype=torch.uint8), columnwise_scale_inv=torch.zeros( columnwise_scale_dims, device="cuda", dtype=torch.float32 ), @@ -146,9 +142,7 @@ def test_quantize_dequantize_dtypes( columnwise=False, block_scaling_dim=block_scaling_dim, ) - self._test_quantize_dequantize( - quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol - ) + self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol) @pytest.mark.parametrize( "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] @@ -181,7 +175,7 @@ def test_quantize_dequantize_dims( @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) @pytest.mark.parametrize("dq_columnwise", [True, False]) def test_quantize_dequantize_dims_cpp_allocate_output( - self, dims: DimsType, block_scaling_dim: int, fp8_dtype: tex.DType, dq_columnwise: bool + self, dims: DimsType, block_scaling_dim: int, fp8_dtype: tex.DType, dq_columnwise: bool ) -> None: atol = _tols[fp8_dtype]["atol"] rtol = _tols[fp8_dtype]["rtol"] diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index fbcc0244bc..1f19b14174 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -433,14 +433,10 @@ struct TypeInfo { #define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \ if (CONDITION) { \ constexpr bool FLAG = true; \ - { \ - __VA_ARGS__ \ - } \ + { __VA_ARGS__ } \ } else { \ constexpr bool FLAG = false; \ - { \ - __VA_ARGS__ \ - } \ + { __VA_ARGS__ } \ } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index 893207a51a..cf2cb15174 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -23,26 +23,18 @@ template -__global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel( - const IType* const input, - OType* const output_c, - OType* const output_t, - CType* const tile_scales_inv_c, - CType* const tile_scales_inv_t, - const size_t row_length, - const size_t num_rows, - const size_t scale_stride_x, - const size_t scale_stride_y, - const size_t scale_t_stride_x, - const size_t scale_t_stride_y, - const float epsilon, - const __grid_constant__ CUtensorMap tensor_map_output_t) { +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c, + OType* const output_t, CType* const tile_scales_inv_c, + CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, + const size_t scale_stride_y, const size_t scale_t_stride_x, + const size_t scale_t_stride_y, const float epsilon, + const __grid_constant__ CUtensorMap tensor_map_output_t) { using IVec = Vec; using OVecCast = Vec; using OVecTrans = Vec; @@ -223,9 +217,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose // Initiate TMA transfer to copy shared memory to global memory if (threadIdx.x == 0) { cde::cp_async_bulk_tensor_2d_shared_to_global( - &tensor_map_output_t, - tile_id_y * BLOCK_TILE_DIM, - tile_id_x * BLOCK_TILE_DIM, + &tensor_map_output_t, tile_id_y * BLOCK_TILE_DIM, tile_id_x * BLOCK_TILE_DIM, block_tile_trans_shared_otype_ptr); // Wait for TMA transfer to have finished reading shared memory. // Create a "bulk async-group" out of the previous bulk copy operation. @@ -254,18 +246,10 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose template __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned( - const IType* const input, - OType* const output_c, - OType* const output_t, - CType* const tile_scales_inv_c, - CType* const tile_scales_inv_t, - const size_t row_length, - const size_t num_rows, - const size_t scale_stride_x, - const size_t scale_stride_y, - const size_t scale_t_stride_x, - const size_t scale_t_stride_y, - const float epsilon) { + const IType* const input, OType* const output_c, OType* const output_t, + CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, + const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon) { using IVec = Vec; using OVecCast = Vec; using OVecTrans = Vec; @@ -567,22 +551,18 @@ void nvte_quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleT const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype, - InputType, + input.dtype, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output.dtype, - OutputType, + output.dtype, OutputType, dim3 grid(num_blocks_x, num_blocks_y, 1); const bool full_tile = row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; TRANSFORMER_ENGINE_SWITCH_CONDITION( - return_transpose, - kReturnTranspose, + return_transpose, kReturnTranspose, TRANSFORMER_ENGINE_SWITCH_CONDITION( - pow_2_scale, - kPow2Scale, + pow_2_scale, kPow2Scale, if (full_tile) { CUtensorMap tensor_map_output_trans; @@ -590,51 +570,34 @@ void nvte_quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleT tensor_map_output_trans = get_tensor_map(output_t, num_rows, row_length); } - block_scaled_cast_transpose_kernel< - kReturnTranspose, - kPow2Scale, - float, - InputType, - OutputType><<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), - row_length, - num_rows, - scale_stride_x, - scale_stride_y, - scale_t_stride_x, - scale_t_stride_y, - epsilon, - tensor_map_output_trans); + block_scaled_cast_transpose_kernel + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, + epsilon, tensor_map_output_trans); } else { - block_scaled_cast_transpose_kernel_notaligned< - kReturnTranspose, - kPow2Scale, - float, - InputType, - OutputType><<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), - row_length, - num_rows, - scale_stride_x, - scale_stride_y, - scale_t_stride_x, - scale_t_stride_y, - epsilon); - } // full-tile - - ) // kPow2Scale - ) // kReturnTranspose - ) // OutputType - ) // InputType + block_scaled_cast_transpose_kernel_notaligned + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, + epsilon); + } // full-tile + + ) // kPow2Scale + ) // kReturnTranspose + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } -} // namespace transformer_engine::detail \ No newline at end of file +} // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 1e73e74987..d9676504ed 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -24,7 +24,7 @@ namespace { // clang-format off /* -Step 1: Load input to shared memory +Step 1: Load input to shared memory * shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) * 8 warps * Loop 8 times @@ -137,27 +137,16 @@ constexpr int kNumThreadsStore = kTileDim / kNVecOut; static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); -template < - bool kReturnTranspose, - bool kIsE8Scaling, - bool kPermuteScale, - bool kAligned, - typename CType, - typename IType, - typename OType> -__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( - const IType* const input, - OType* const output_c, - OType* const output_t, - CType* const tile_scales_inv_c, - CType* const tile_scales_inv_t, - const size_t row_length, - const size_t num_rows, - const size_t scale_stride_x, - const size_t scale_stride_y, - const size_t scale_t_stride_x, - const size_t scale_t_stride_y, - const float epsilon) { +template +__global__ void __launch_bounds__(kThreadsPerBlock) + block_scaled_1d_cast_transpose_kernel(const IType* const input, OType* const output_c, + OType* const output_t, CType* const tile_scales_inv_c, + CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, + const size_t scale_stride_y, + const size_t scale_t_stride_x, + const size_t scale_t_stride_y, const float epsilon) { using SMemVec = Vec; using OVec = Vec; union IVec { @@ -491,73 +480,49 @@ void nvte_quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleT const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype, - InputType, + input.dtype, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output.dtype, - OutputType, + output.dtype, OutputType, dim3 grid(num_blocks_x, num_blocks_y, 1); const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; TRANSFORMER_ENGINE_SWITCH_CONDITION( - return_transpose, - kReturnTranspose, + return_transpose, kReturnTranspose, TRANSFORMER_ENGINE_SWITCH_CONDITION( - pow2_scale, - kPow2Scale, + pow2_scale, kPow2Scale, TRANSFORMER_ENGINE_SWITCH_CONDITION( - full_tile, - kAligned, + full_tile, kAligned, size_t smem_bytes = kSMemSize * sizeof(InputType); // shared memory must be requested up if (smem_bytes >= 48 * 1024) { cudaError_t err = cudaFuncSetAttribute( - &block_scaled_1d_cast_transpose_kernel< - kReturnTranspose, - kPow2Scale, - kPermuteScale, - kAligned, - float, - InputType, - OutputType>, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_bytes); - NVTE_CHECK( - err == cudaSuccess, "Failed to set dynamic shared memory size."); - } - block_scaled_1d_cast_transpose_kernel< - kReturnTranspose, - kPow2Scale, - kPermuteScale, - kAligned, - float, - InputType, - OutputType> - <<>>( + &block_scaled_1d_cast_transpose_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size."); + } block_scaled_1d_cast_transpose_kernel + <<>>( reinterpret_cast(input.dptr), reinterpret_cast(output.dptr), reinterpret_cast(output_t.dptr), reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), - row_length, - num_rows, - scale_stride_x, - scale_stride_y, - scale_t_stride_x, - scale_t_stride_y, - epsilon); - ) // kAligned - ) // kPow2Scale - ) // kReturnTranspose - ) // OutputType - ) // InputType + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, + epsilon);) // kAligned + ) // kPow2Scale + ) // kReturnTranspose + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } -} // namespace transformer_engine::detail \ No newline at end of file +} // namespace transformer_engine::detail diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 3ca5f5c124..229862da88 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1248,29 +1248,21 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe case NVTE_BLOCK_SCALING: { // FIXME(kwyss): Currently ignoring IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters. if (output_tensor->block_scaling_dim == 2) { - nvte_quantize_transpose_square_blockwise( - input_tensor->data, - output_tensor->scale_inv, - output_tensor->columnwise_scale_inv, - output_tensor->data, - output_tensor->columnwise_data, - /*epsilon=*/ output_tensor->amax_epsilon, - /*return_transpose=*/ output_tensor->has_columnwise_data(), - output_tensor->force_pow_2_scales, - stream); + nvte_quantize_transpose_square_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, + /*epsilon=*/output_tensor->amax_epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), + output_tensor->force_pow_2_scales, stream); } else if (output_tensor->block_scaling_dim == 1) { - nvte_quantize_transpose_vector_blockwise( - input_tensor->data, - output_tensor->scale_inv, - output_tensor->columnwise_scale_inv, - output_tensor->data, - output_tensor->columnwise_data, - /*epsilon=*/ output_tensor->amax_epsilon, - /*return_transpose=*/ output_tensor->has_columnwise_data(), - output_tensor->force_pow_2_scales, - stream); + nvte_quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, + /*epsilon=*/output_tensor->amax_epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), + output_tensor->force_pow_2_scales, stream); } else { - NVTE_ERROR("Not supported block scaling dim."); + NVTE_ERROR("Not supported block scaling dim."); } break; } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index e1e8da1bb6..66dafdaafb 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -13,9 +13,7 @@ namespace transformer_engine::pytorch { -py::object quantize(const at::Tensor& tensor, - py::handle quantizer, - const py::object& output, +py::object quantize(const at::Tensor& tensor, py::handle quantizer, const py::object& output, std::optional noop) { init_extension(); auto my_quantizer = convert_quantizer(quantizer); diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index b1c77d84ec..fafc9043e3 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -239,8 +239,7 @@ std::pair Float8BlockQuantizer::create_tensor( data_colwise = at::empty(torch_columnwise_shape, opts); scale_inv_colwise = at::empty({sinv0, sinv1}, scale_opts); - tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, - columnwise_shape); + tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, columnwise_shape); tensor.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32, std::vector{sinv0, sinv1}); } diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index b1a61fae4b..ffed102ee5 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -89,9 +89,7 @@ def get_data_tensors(self): """Get this Tensor's data.""" return self._rowwise_data, self._columnwise_data - def _transpose_dq_columnwise_output( - self, columnwise_dq: torch.Tensor - ) -> torch.Tensor: + def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor: """Takes dequantized columnwise data and permutes to a rowwise shape""" if columnwise_dq.dim() < 2: return columnwise_dq @@ -99,9 +97,7 @@ def _transpose_dq_columnwise_output( permute_dims.append(0) return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous() - def _dequantize_vectorwise( - self, *, dtype: torch.dtype = torch.float32 - ) -> torch.Tensor: + def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: block_len = 128 q_M, q_K = 1, 1 diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 9a5568af48..8090ac20b8 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -74,9 +74,7 @@ def update_quantized( dst._fp8_dtype = self.dtype return dst - def get_scale_shape( - self, shape: Iterable[int], columnwise: bool - ) -> Tuple[int, int]: + def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: # cuBLAS kernel format (for NxN by NxN and 1xN by NxN GEMMs) # The scales for 2D block quantized tensors must have scales padded # to multiples of 4 on the inner dimension. TODO: Verify whether outer @@ -87,16 +85,12 @@ def get_scale_shape( logical_scale_shape[-2] *= shape[i] if len(shape) > 0: logical_scale_shape[-1] = math.ceil(shape[-1] / self.block_len) - logical_scale_shape[-2] = math.ceil( - logical_scale_shape[-2] / self.block_len - ) + logical_scale_shape[-2] = math.ceil(logical_scale_shape[-2] / self.block_len) if columnwise: tmp = logical_scale_shape[-1] logical_scale_shape[-1] = logical_scale_shape[-2] logical_scale_shape[-2] = tmp - logical_scale_shape[-1] = round_up_to_nearest_multiple( - logical_scale_shape[-1], 4 - ) + logical_scale_shape[-1] = round_up_to_nearest_multiple(logical_scale_shape[-1], 4) return tuple(logical_scale_shape) else: assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported" @@ -107,14 +101,10 @@ def get_scale_shape( if len(shape) > 0: logical_scale_shape[-2] = shape[-1] if not columnwise: - logical_scale_shape[-2] = math.ceil( - logical_scale_shape[-2] / self.block_len - ) + logical_scale_shape[-2] = math.ceil(logical_scale_shape[-2] / self.block_len) return tuple(logical_scale_shape) else: - logical_scale_shape[-1] = math.ceil( - logical_scale_shape[-1] / self.block_len - ) + logical_scale_shape[-1] = math.ceil(logical_scale_shape[-1] / self.block_len) return (logical_scale_shape[1], logical_scale_shape[0]) def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: @@ -205,7 +195,10 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): """ def __repr__(self, *, tensor_contents=None): - return f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}, data={self.dequantize(dtype=self.dtype)})" + return ( + f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," + f" data={self.dequantize(dtype=self.dtype)})" + ) def _get_quantizer(self) -> Quantizer: """Get builder for quantized tensor @@ -281,8 +274,7 @@ def update_usage(self, rowwise_usage=True, columnwise_usage=True): return if columnwise_usage: assert ( - self._columnwise_data is not None - and self._columnwise_scale_inv is not None + self._columnwise_data is not None and self._columnwise_scale_inv is not None ), "Cannot update to columnwise usage." self._rowwise_data = None self._rowwise_scale_inv = None @@ -332,16 +324,12 @@ def contiguous( ) ): return self - raise ValueError( - "Float8BlockwiseQTensor does not support different memory formats!" - ) + raise ValueError("Float8BlockwiseQTensor does not support different memory formats!") def clear(self): """Deallocate this tensor's memory. Typically not needed and must be used carefully.""" self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None - self._columnwise_data = ( - torch.Tensor() if self._columnwise_data is not None else None - ) + self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -448,9 +436,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: device=new_device, ) # pylint: disable=unnecessary-dunder-call - super(Float8BlockwiseQTensor, type(self)).data.__set__( - self, dummy_tensor - ) + super(Float8BlockwiseQTensor, type(self)).data.__set__(self, dummy_tensor) self._rowwise_data = tensor._rowwise_data self._columnwise_data = tensor._columnwise_data self._quantizer = tensor._quantizer