diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index ce78fcaae2..ff225cccba 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -10,6 +10,7 @@ add_executable(test_operator test_cast_mxfp8_gated_swiglu.cu test_qdq.cu test_cast_mxfp8.cu + test_cast_float8blockwise.cu test_dequantize_mxfp8.cu test_transpose.cu test_cast_transpose.cu diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu new file mode 100644 index 0000000000..171d22be71 --- /dev/null +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -0,0 +1,640 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" + +using namespace transformer_engine; +using namespace test; + +namespace { + +constexpr size_t kBlockLen = 128; + +enum ProcessingMethod { + CAST_ONLY, + // CAST_DBIAS, + // CAST_DBIAS_DACT, + // CAST_DACT, + // CAST_ACT +}; + +enum ActivationType { + Identity, + // GeLU, + // SiLU, + // ReLU, + // QGeLU, + // SReLU +}; + +template +void scales_from_amax(float amax, const QuantizationOptions& opts, float* qscale_out, + float* qscale_inv_out) { + float input_type_max_val = Quantized_Limits::max(); + float quant_type_max_val = Quantized_Limits::max(); + float eps = opts.amax_epsilon; + amax = std::max(amax, eps); + float qscale = quant_type_max_val / amax; + if (std::isinf(qscale)) { + qscale = input_type_max_val; + } + if (std::isnan(qscale) || amax == 0) { + qscale = 1.0; + } + + if (opts.force_pow_2_scales && qscale != 0.0) { + uint32_t scale_bits = *reinterpret_cast(&qscale); + // Scale must be positive, shift it + uint8_t exp = scale_bits >> 23; + ASSERT_FALSE(exp == 0) << "Subnormals in this path is a logic error."; + qscale = ldexpf(1.0f, static_cast(exp) - 127); + } + + float qscale_inv = 1.0 / qscale; + *qscale_out = qscale; + *qscale_inv_out = qscale_inv; +} + +template +void ref_quantize(const ProcessingMethod processing_method, const InputType* input, + const std::pair& input_hw, OutputType* output, float* scale_inv, + OutputType* output_t, float* scale_inv_t, const QuantizationOptions& opts) { + constexpr size_t kBlockLenX = kBlockLen; + constexpr size_t kBlockLenY = kBlockLen; + + auto quantize_element = [](InputType element, float qscale) -> OutputType { + // Scale in FP32 and cast result to nearest FP8. + return static_cast(float(element) * qscale); + }; + + size_t height = input_hw.first; + size_t width = input_hw.second; + size_t blocks_x = (width + kBlockLenX - 1) / kBlockLenX; + size_t blocks_y = (height + kBlockLenY - 1) / kBlockLenY; + // Find the absolute maximum value in the block + for (size_t block_x = 0; block_x < blocks_x; ++block_x) { + for (size_t block_y = 0; block_y < blocks_y; ++block_y) { + float amax = 0.0f; + // Calculate amax for a tile. 
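      // (Each (block_y, block_x) pair corresponds to one 128x128 tile; positions
      // past the tensor edge are skipped, so partial edge tiles contribute only
      // their valid elements.)
      //
      // Illustrative scale arithmetic for the scales_from_amax call below,
      // assuming an E4M3 output (max 448): amax = 3.0 gives
      // qscale = 448 / 3 ~= 149.3; with force_pow_2_scales the exponent is
      // truncated, so qscale becomes 128 and qscale_inv = 1 / 128.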
+ for (size_t i = 0; i < kBlockLenX; ++i) { + for (size_t j = 0; j < kBlockLenY; ++j) { + size_t x_pos = i + block_x * kBlockLenX; + size_t y_pos = j + block_y * kBlockLenY; + if (y_pos >= height || x_pos >= width) { + continue; + } + float val = static_cast(input[y_pos * width + x_pos]); + amax = std::max(amax, std::abs(val)); + } + } + + // We've calculated amax for a tile. Calculate scale and + // scale_inv and populate outputs. + float qscale, qscale_inv; + scales_from_amax(amax, opts, &qscale, &qscale_inv); + + // NOTE: This reference function outputs contigous scale tensors. + // It calculates a naive scale data format. Strides are handled + // in comparison. + if (scale_inv != nullptr) { + scale_inv[block_y * blocks_x + block_x] = qscale_inv; + } + if (scale_inv_t != nullptr) { + scale_inv_t[block_x * blocks_y + block_y] = qscale_inv; + } + + for (size_t i = 0; i < kBlockLenX; ++i) { + for (size_t j = 0; j < kBlockLenY; ++j) { + size_t x_pos = i + block_x * kBlockLenX; + size_t y_pos = j + block_y * kBlockLenY; + if (y_pos >= height || x_pos >= width) { + continue; + } + if (output != nullptr) { + output[y_pos * width + x_pos] = quantize_element(input[y_pos * width + x_pos], qscale); + } + if (output_t != nullptr) { + output_t[x_pos * height + y_pos] = + quantize_element(input[y_pos * width + x_pos], qscale); + } + } + } + } + } +} + +template +void ref_quantize_onedimensional_blocks(const ProcessingMethod processing_method, + const InputType* input, + const std::pair& input_hw, + OutputType* output, float* scale_inv, OutputType* output_t, + float* scale_inv_t, const QuantizationOptions& opts) { + float input_type_max_val = Quantized_Limits::max(); + float quant_type_max_val = Quantized_Limits::max(); + + constexpr size_t kBlockLenX = kBlockLen; + + auto quantize_element = [](InputType element, float qscale) -> OutputType { + // Scale in FP32 and cast result to nearest FP8. + return static_cast(float(element) * qscale); + }; + + size_t height = input_hw.first; + size_t width = input_hw.second; + size_t blocks_x = (width + kBlockLenX - 1) / kBlockLenX; + size_t blocks_x_t = (height + kBlockLenX - 1) / kBlockLenX; + if (output != nullptr && scale_inv != nullptr) { + // Find the absolute maximum value in the block + for (size_t block_x = 0; block_x < blocks_x; ++block_x) { + for (size_t y = 0; y < height; ++y) { + float amax = 0.0f; + // Calculate amax for a tile. + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t x_pos = i + block_x * kBlockLenX; + if (x_pos >= width) { + continue; + } + float val = static_cast(input[y * width + x_pos]); + amax = std::max(amax, std::abs(val)); + } + + // We've calculated amax for a tile. Calculate scale and + // scale_inv and populate outputs. + float qscale, qscale_inv; + scales_from_amax(amax, opts, &qscale, &qscale_inv); + + scale_inv[y + height * block_x] = qscale_inv; + + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t x_pos = i + block_x * kBlockLenX; + if (x_pos >= width) { + continue; + } + output[y * width + x_pos] = quantize_element(input[y * width + x_pos], qscale); + } + } + } + } + if (output_t != nullptr && scale_inv_t != nullptr) { + // Find the absolute maximum value in the block + for (size_t block_x_t = 0; block_x_t < blocks_x_t; ++block_x_t) { + for (size_t x = 0; x < width; ++x) { + float amax = 0.0f; + // Calculate amax for a tile. 
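      // (Transposed path: each 1x128 block covers 128 consecutive rows of a
      // single input column, i.e. one row segment of output_t. scale_inv_t is
      // laid out as a contiguous (ceil(height / 128), width) array.)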
+ for (size_t i = 0; i < kBlockLenX; ++i) { + size_t y_pos = i + block_x_t * kBlockLenX; + if (y_pos >= height) { + continue; + } + float val = static_cast(input[x + y_pos * width]); + amax = std::max(amax, std::abs(val)); + } + + // We've calculated amax for a tile. Calculate scale and + // scale_inv and populate outputs. + float qscale, qscale_inv; + scales_from_amax(amax, opts, &qscale, &qscale_inv); + + scale_inv_t[x + width * block_x_t] = qscale_inv; + + for (size_t i = 0; i < kBlockLenX; ++i) { + size_t y_pos = i + block_x_t * kBlockLenX; + if (y_pos >= height) { + continue; + } + output_t[x * height + y_pos] = quantize_element(input[y_pos * width + x], qscale); + } + } + } + } +} + +void compare_scaling_factors(const std::string& name, const float* test, const float* ref, + const size_t row_blocks, const size_t col_blocks, + const size_t test_stride, const size_t ref_stride) { + for (int i = 0; i < row_blocks; ++i) { + for (int j = 0; j < col_blocks; ++j) { + const int test_idx = i * test_stride + j; + const int ref_idx = i * ref_stride + j; + ASSERT_FALSE(test[test_idx] != ref[ref_idx]) + << "Error in " << name << std::endl + << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx + << "," << ref_idx; + } + } +} + +void compare_scaling_factors_one_dimensional_blocks(const std::string& name, const float* test, + const float* ref, const size_t rows, + const size_t col_blocks) { + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < col_blocks; ++j) { + const int test_idx = i + rows * j; + const int ref_idx = i + rows * j; + ASSERT_FALSE(test[test_idx] != ref[ref_idx]) + << "Error in " << name << std::endl + << "Mismatch: " << test[test_idx] << " vs " << ref[ref_idx] << " at index " << test_idx + << "," << ref_idx; + } + } +} + +template +void runTestCase(const ProcessingMethod processing_method, const std::vector& shape, + const bool rowwise, const bool colwise, InputsFillCase fill_case, + const QuantizationOptions& opts) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + size_t blocks_x = (cols + kBlockLen - 1) / kBlockLen; + size_t blocks_y = (rows + kBlockLen - 1) / kBlockLen; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_BLOCK_SCALING, &opts); + Tensor output_dbias("output_dbias", {cols}, itype); + + std::unique_ptr ref_output = std::make_unique(rows * cols); + std::unique_ptr ref_output_t = std::make_unique(rows * cols); + std::unique_ptr ref_scale_inv = std::make_unique(blocks_y * blocks_x); + std::unique_ptr ref_scale_inv_t = std::make_unique(blocks_y * blocks_x); + + if (!rowwise) { + ref_output = nullptr; + ref_scale_inv = nullptr; + } + if (!colwise) { + ref_output_t = nullptr; + ref_scale_inv_t = nullptr; + } + + fillCase(&input, fill_case); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output_c.data(), 0); + break; + } + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + ref_quantize(processing_method, input.rowwise_cpu_dptr(), + {rows, cols}, ref_output.get(), ref_scale_inv.get(), + ref_output_t.get(), ref_scale_inv_t.get(), opts); + + float atol = 0.0; + float rtol = 0.0; + + auto scale_align_stride = 
[](size_t inner_elements) -> size_t { + return ((inner_elements + 4u - 1u) / 4u) * 4u; + }; + + if (rowwise) { + compareResults("output_c", output_c, ref_output.get(), true, atol, rtol); + compare_scaling_factors("scale_inv", output_c.rowwise_cpu_scale_inv_ptr(), + ref_scale_inv.get(), blocks_y, blocks_x, scale_align_stride(blocks_x), + blocks_x); + } + if (colwise) { + compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol); + compare_scaling_factors("scale_inv_t", output_c.columnwise_cpu_scale_inv_ptr(), + ref_scale_inv_t.get(), blocks_x, blocks_y, scale_align_stride(blocks_y), + blocks_y); + } +} + +template +void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method, + const std::vector& shape, const bool rowwise, + const bool colwise, InputsFillCase fill_case, + const QuantizationOptions& opts) { + using namespace test; + using EncodingType = fp32; + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + size_t blocks_x = (cols + kBlockLen - 1) / kBlockLen; + size_t blocks_x_t = (rows + kBlockLen - 1) / kBlockLen; + + Tensor input("input", shape, itype); + Tensor grad("grad", shape, itype); + Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_BLOCK_SCALING, &opts); + Tensor output_dbias("output_dbias", {cols}, itype); + + std::unique_ptr ref_output = std::make_unique(rows * cols); + std::unique_ptr ref_output_t = std::make_unique(rows * cols); + std::unique_ptr ref_scale_inv = std::make_unique(rows * blocks_x); + std::unique_ptr ref_scale_inv_t = std::make_unique(cols * blocks_x_t); + + if (!rowwise) { + ref_output = nullptr; + ref_scale_inv = nullptr; + } + if (!colwise) { + ref_output_t = nullptr; + ref_scale_inv_t = nullptr; + } + + fillCase(&input, fill_case); + fillUniform(&grad); + + Tensor workspace; + switch (processing_method) { + case ProcessingMethod::CAST_ONLY: { + nvte_quantize(input.data(), output_c.data(), 0); + break; + } + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + ref_quantize_onedimensional_blocks( + processing_method, input.rowwise_cpu_dptr(), {rows, cols}, ref_output.get(), + ref_scale_inv.get(), ref_output_t.get(), ref_scale_inv_t.get(), opts); + + float atol = 0.0; + float rtol = 0.0; + + if (rowwise) { + compareResults("output_c", output_c, ref_output.get(), true, atol, rtol); + compare_scaling_factors_one_dimensional_blocks("scale_inv", + output_c.rowwise_cpu_scale_inv_ptr(), + ref_scale_inv.get(), rows, blocks_x); + } + if (colwise) { + compareResults("output_c_t", output_c, ref_output_t.get(), false, atol, rtol); + compare_scaling_factors_one_dimensional_blocks("scale_inv_t", + output_c.columnwise_cpu_scale_inv_ptr(), + ref_scale_inv_t.get(), cols, blocks_x_t); + } +} + +std::vector> matrix_sizes = { + {1, 16}, {16, 48}, {65, 96}, {128, 128}, {256, 256}, {993, 512}, + {256, 65536}, {2048, 6144}, {16384, 128}, {32768, 160}, {4096, 1632}, {1024, 1}, + {32, 1024}, {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, +}; + +std::vector input_scenarios = { + InputsFillCase::uniform, +}; + +std::vector processing_methods = { + ProcessingMethod::CAST_ONLY, + // ProcessingMethod::CAST_DBIAS, + // ProcessingMethod::CAST_DBIAS_DACT, + // ProcessingMethod::CAST_DACT, + // ProcessingMethod::CAST_ACT, +}; + +// Only GeLU activation tests are supported +std::vector Activation_types = { + ActivationType::Identity, + // 
ActivationType::GeLU, + // ActivationType::SiLU, + // ActivationType::ReLU, + // ActivationType::QGeLU, + // ActivationType::SReLU, +}; + + +std::vector amax_epsilons = { + 0.0f, + // Set large epsilon to get observable behavior. + 0.1f, +}; + +} // namespace + +class FusedCastFloat8BlockwiseTestSuite + : public ::testing::TestWithParam, transformer_engine::DType, + transformer_engine::DType, InputsFillCase, bool, float, bool>> {}; + +class FusedCastFloat8VectorwiseTestSuite + : public ::testing::TestWithParam, transformer_engine::DType, + transformer_engine::DType, InputsFillCase, bool, float, bool>> {}; + +#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ + switch (OP_FUNC_TYPE) { \ + case ActivationType::Identity: { \ + constexpr auto OP = &identity; \ + { \ + __VA_ARGS__ \ + } \ + } break; \ + } + +#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ + switch (OP_FUNC_TYPE) { \ + case ActivationType::Identity: { \ + constexpr auto OP = &identity; \ + { \ + __VA_ARGS__ \ + } \ + } break; \ + } + +TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) { + if (getDeviceComputeCapability() < hopperComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationType Act_type = std::get<1>(GetParam()); + const auto matrix_size = std::get<2>(GetParam()); + const DType input_type = std::get<3>(GetParam()); + const DType output_type = std::get<4>(GetParam()); + const InputsFillCase fill_case = std::get<5>(GetParam()); + const bool colwise = std::get<6>(GetParam()); + const bool rowwise = true; + const float eps = std::get<7>(GetParam()); + const bool force_pow_2 = std::get<8>(GetParam()); + + QuantizationOptions q_opts; + q_opts.force_pow_2_scales = force_pow_2; + q_opts.amax_epsilon = eps; + q_opts.block_scaling_dim = 2u; + + if (colwise && matrix_size.size() < 2) { + // test_common Tensor initialization code does not + // handle this case. 
+ GTEST_SKIP(); + } + // Skips non Act tests if the Activation type is not an identity + if ( // (processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + (processing_method == ProcessingMethod::CAST_ONLY) && Act_type != ActivationType::Identity) { + GTEST_SKIP(); + } + // Skips Act tests if the Activation is an identity + // if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + // || processing_method == ProcessingMethod::CAST_DACT + // || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) { + // GTEST_SKIP(); + // } + + DACT_FUNC_SWITCH( + Act_type, OP, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( + output_type, OutputType, + runTestCase(processing_method, matrix_size, rowwise, colwise, + fill_case, q_opts);););); +} + +TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) { + if (getDeviceComputeCapability() < hopperComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ProcessingMethod processing_method = std::get<0>(GetParam()); + const ActivationType Act_type = std::get<1>(GetParam()); + const auto matrix_size = std::get<2>(GetParam()); + const DType input_type = std::get<3>(GetParam()); + const DType output_type = std::get<4>(GetParam()); + const InputsFillCase fill_case = std::get<5>(GetParam()); + const bool colwise = std::get<6>(GetParam()); + const bool rowwise = true; + const float eps = std::get<7>(GetParam()); + const bool force_pow_2 = std::get<8>(GetParam()); + + QuantizationOptions q_opts; + q_opts.force_pow_2_scales = force_pow_2; + q_opts.amax_epsilon = eps; + q_opts.block_scaling_dim = 1u; + + if (colwise && matrix_size.size() < 2) { + // test_common Tensor initialization code does not + // handle this case. 
+ GTEST_SKIP(); + } + // Skips non Act tests if the Activation type is not an identity + if ( // (processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) + (processing_method == ProcessingMethod::CAST_ONLY) && Act_type != ActivationType::Identity) { + GTEST_SKIP(); + } + // Skips Act tests if the Activation is an identity + // if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT + // || processing_method == ProcessingMethod::CAST_DACT + // || processing_method == ProcessingMethod::CAST_ACT) && (Act_type == ActivationType::Identity)) { + // GTEST_SKIP(); + // } + + DACT_FUNC_SWITCH( + Act_type, OP, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY( + input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( + output_type, OutputType, + runTestCaseOneDimensionalBlocks( + processing_method, matrix_size, rowwise, colwise, fill_case, q_opts);););); +} + +std::string to_string(const ProcessingMethod method) { + switch (method) { + case ProcessingMethod::CAST_ONLY: + return "CAST_ONLY"; + // case ProcessingMethod::CAST_DBIAS: return "CAST_DBIAS"; + // case ProcessingMethod::CAST_DBIAS_DACT: return "CAST_DBIAS_DACT"; + // case ProcessingMethod::CAST_DACT: return "CAST_DACT"; + // case ProcessingMethod::CAST_ACT: return "CAST_ACT"; + default: + return ""; + } +} + +std::string to_string(const ActivationType Act_type) { + switch (Act_type) { + case ActivationType::Identity: + return "Identity"; + // case ActivationType::GeLU: return "GeLU"; + // case ActivationType::SiLU: return "SiLU"; + // case ActivationType::ReLU: return "ReLU"; + // case ActivationType::QGeLU: return "QGeLU"; + // case ActivationType::SReLU: return "SReLU"; + default: + return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, FusedCastFloat8BlockwiseTestSuite, + ::testing::Combine(::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(matrix_sizes), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(false, true)), + [](const testing::TestParamInfo& info) { + std::string name = + to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for (const auto& s : shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<3>(info.param)) + "X" + + test::typeName(std::get<4>(info.param)) + "X" + + test::caseName(std::get<5>(info.param)) + "X" + + std::to_string(std::get<6>(info.param)) + "X" + + std::to_string(std::get<7>(info.param) != 0.0f) + "X" + + std::to_string(std::get<8>(info.param)); + return name; + }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, FusedCastFloat8VectorwiseTestSuite, + ::testing::Combine(::testing::ValuesIn(processing_methods), + ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(matrix_sizes), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), + ::testing::ValuesIn(input_scenarios), ::testing::Values(true, false), + ::testing::ValuesIn(amax_epsilons), ::testing::Values(false, true)), + [](const testing::TestParamInfo& info) { + std::string name = + to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param)); + const auto& shape = std::get<2>(info.param); + for (const auto& s : shape) { + 
name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<3>(info.param)) + "X" + + test::typeName(std::get<4>(info.param)) + "X" + + test::caseName(std::get<5>(info.param)) + "X" + + std::to_string(std::get<6>(info.param)) + "X" + + std::to_string(std::get<7>(info.param) != 0.0f) + "X" + + std::to_string(std::get<8>(info.param)); + return name; + }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index ec4a9bdbb7..78bbb97687 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -118,7 +119,8 @@ NVTEShape convertShape(const std::vector& shape) { } std::pair get_scales(const NVTEShape& shape, - const NVTEScalingMode scaling_mode) { + const NVTEScalingMode scaling_mode, + const int block_scaling_dim) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { scale_inv_meta ret; ret.shape = {1}; @@ -136,27 +138,19 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; - auto block_alignment = std::vector{128ul,4ul}; + auto block_alignment = std::vector{128ul, 4ul}; { auto alignment = block_alignment[0]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, - static_cast(1)), - alignment) * alignment; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(1)), alignment) * alignment; alignment = block_alignment[1]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, - static_cast(32)), - alignment) * alignment; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(32)), alignment) * alignment; ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { auto alignment = block_alignment[1]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, - static_cast(32)), - alignment) * alignment; + auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(32)), alignment) * alignment; alignment = block_alignment[0]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, - static_cast(1)), - alignment) * alignment; + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(1)), alignment) * alignment; ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat8E8M0; @@ -166,6 +160,61 @@ std::pair get_scales(const NVTEShape& shape, return {ret_rowwise, ret_colwise}; } + if (scaling_mode == NVTE_BLOCK_SCALING) { + if (block_scaling_dim == 2) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + scale_inv_meta ret_rowwise, ret_colwise; + + { + auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(128)), 4) * 4; + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + } + { + auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); + auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast(128)), 4) * 4; + ret_colwise.shape = {scale_dim_0, scale_dim_1}; + } + ret_rowwise.type = DType::kFloat32; + ret_colwise.type = DType::kFloat32; + ret_rowwise.type_size = sizeof(float); + ret_colwise.type_size = sizeof(float); + + return {ret_rowwise, ret_colwise}; + } else if (block_scaling_dim == 1) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); + } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + scale_inv_meta ret_rowwise, ret_colwise; + + { + auto scale_dim_1 = first_dim; + auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); + ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + 
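        // For 1x128 blocks the scales are stored transposed relative to the data:
        // (ceil(last_dim / 128), first_dim), which is the layout the cuBLAS GEMM
        // expects for vector-wise scaling (see blockwise_quantizer_reference.py).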
} + { + auto scale_dim_1 = last_dim; + auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); + ret_colwise.shape = {scale_dim_0, scale_dim_1}; + } + ret_rowwise.type = DType::kFloat32; + ret_colwise.type = DType::kFloat32; + ret_rowwise.type_size = sizeof(float); + ret_colwise.type_size = sizeof(float); + return {ret_rowwise, ret_colwise}; + } else { + NVTE_ERROR("Unsupported block scaling dim!"); + } + } NVTE_ERROR("Invalid scaling mode!"); } @@ -173,7 +222,8 @@ std::pair get_scales(const NVTEShape& shape, Tensor::Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise, const bool columnwise, - const NVTEScalingMode &scaling_mode) { + const NVTEScalingMode &scaling_mode, + const QuantizationOptions* q_opts) { name_ = name; const size_t seed = create_seed_from_tensor_name(name); gen_.seed(seed); @@ -197,9 +247,12 @@ Tensor::Tensor(const std::string& name, std::vector normalized_shape_v = {product(shape, 0, shape.ndim - 1), shape.data[shape.ndim - 1]}; NVTEShape normalized_shape = convertShape(normalized_shape_v); - + size_t block_scaling_dim = 0; + if (q_opts != nullptr) { + block_scaling_dim = q_opts->block_scaling_dim; + } std::vector columnwise_shape_vec; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING) { // Transpose when tensor scaling columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); for (size_t i = 0; i < shape.ndim - 1; ++i) { @@ -256,27 +309,34 @@ Tensor::Tensor(const std::string& name, std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); } } else { - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, - tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(normalized_shape, tensor_.scaling_mode(), block_scaling_dim); auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size; auto scale_shape = rowwise_scale_meta.shape; auto columnwise_scale_shape = colwise_scale_meta.shape; if (rowwise) { - cudaMalloc((void**)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) + cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size); // NOLINT(*) cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size); rowwise_scale_inv_cpu_data_ = std::make_unique(rowwise_scale_size); std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0); - tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat8E8M0, scale_shape); + auto scale_dtype = rowwise_scale_meta.type; + tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape); } if (columnwise) { cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size); // NOLINT(*) cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size); columnwise_scale_inv_cpu_data_ = std::make_unique(columnwise_scale_size); std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0); - tensor_.set_columnwise_scale_inv(columnwise_scale_inv, DType::kFloat8E8M0, columnwise_scale_shape); + auto scale_dtype = colwise_scale_meta.type; + tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape); } } + if (q_opts != nullptr) { + tensor_.set_qopt_force_pow_2_scales(q_opts->force_pow_2_scales); + tensor_.set_qopt_amax_epsilon(q_opts->amax_epsilon); + tensor_.set_qopt_block_scaling_dim(q_opts->block_scaling_dim); + } } } @@ -306,7 +366,8 @@ void Tensor::to_cpu() 
const { sizeof(float), cudaMemcpyDeviceToHost); } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(s, tensor_.scaling_mode(), tensor_.get_qopt_block_scaling_dim()); if (rowwise_) { auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), @@ -342,7 +403,8 @@ void Tensor::from_cpu() const { cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(s, tensor_.scaling_mode(), tensor_.get_qopt_block_scaling_dim()); if (rowwise_) { auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size; cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, @@ -361,7 +423,7 @@ void Tensor::from_cpu() const { void Tensor::set_scale(float scale) { if (isFp8Type(dtype())) { NVTE_CHECK(scale_cpu_data_); - if (is_tensor_scaling(tensor_.scaling_mode())) { + if (is_tensor_scaling(tensor_.scaling_mode())) { *scale_cpu_data_ = scale; from_cpu(); } @@ -376,27 +438,29 @@ void Tensor::set_scale_inv(float scale_inv) { if (columnwise_) { NVTE_CHECK(columnwise_scale_inv_cpu_data_); } - auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode()); + + auto [rowwise_scale_meta, colwise_scale_meta] = + get_scales(tensor_.shape(), tensor_.scaling_mode(), tensor_.get_qopt_block_scaling_dim()); if (rowwise_) { auto num_scales = product(rowwise_scale_meta.shape); - if (num_scales == 1){ + if (num_scales == 1) { rowwise_cpu_scale_inv_ptr()[0] = scale_inv; - } else{ + } else { std::uniform_int_distribution dis(0, 127); - auto* scale_inv_ptr = rowwise_cpu_scale_inv_ptr(); - for (size_t i = 0; i < num_scales; i++){ + auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++) { scale_inv_ptr[i] = dis(gen_); } } } if (columnwise_) { auto num_scales = product(colwise_scale_meta.shape); - if (num_scales == 1){ + if (num_scales == 1) { columnwise_cpu_scale_inv_ptr()[0] = scale_inv; - } else{ + } else { std::uniform_int_distribution dis(0, 127); - auto* scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); - for (size_t i = 0; i < num_scales; i++){ + auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); + for (size_t i = 0; i < num_scales; i++) { scale_inv_ptr[i] = dis(gen_); } } @@ -406,23 +470,20 @@ void Tensor::set_scale_inv(float scale_inv) { } void Tensor::shareFP8Meta(const Tensor &other) { - if(isFp8Type(dtype()) && isFp8Type(other.dtype())) { + if (isFp8Type(dtype()) && isFp8Type(other.dtype())) { auto new_tensor = TensorWrapper(other.tensor_.scaling_mode()); auto my_rowwise_data = tensor_.get_rowwise_data(); - new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, - static_cast(my_rowwise_data.dtype), + new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast(my_rowwise_data.dtype), my_rowwise_data.shape); auto my_columnwise_data = tensor_.get_columnwise_data(); new_tensor.set_columnwise_data(my_columnwise_data.data_ptr, static_cast(my_columnwise_data.dtype), my_columnwise_data.shape); auto other_amax = other.tensor_.get_amax(); - new_tensor.set_amax(other_amax.data_ptr, - static_cast(other_amax.dtype), + new_tensor.set_amax(other_amax.data_ptr, static_cast(other_amax.dtype), other_amax.shape); auto other_scale = other.tensor_.get_scale(); - 
new_tensor.set_scale(other_scale.data_ptr, - static_cast(other_scale.dtype), + new_tensor.set_scale(other_scale.data_ptr, static_cast(other_scale.dtype), other_scale.shape); auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv(); new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr, @@ -453,9 +514,7 @@ std::string to_string(const std::vector &v) { std::vector unravel(const size_t i, const NVTEShape &shape) { std::vector ret; size_t current_i = i; - for (size_t current = shape.ndim - 1; - current > 0; - --current) { + for (size_t current = shape.ndim - 1; current > 0; --current) { ret.push_back(current_i % shape.data[current]); current_i /= shape.data[current]; } @@ -693,7 +752,7 @@ void fillCase_special(Tensor *t) { }); } else { double minAbs = -2.0; - double maxAbs = 1.0; + double maxAbs = 1.0; if constexpr (Case != InputsFillCase::uniform) { minAbs = Quantized_Limits::ranges[Case]; maxAbs = Quantized_Limits::ranges[Case + 1]; @@ -752,14 +811,13 @@ void setRandomScaleInv(Tensor *t) { } bool isFp8Type(DType type) { - return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; + return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; } -int32_t getDeviceComputeCapability() -{ - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); - return 10 * deviceProp.major + deviceProp.minor; +int32_t getDeviceComputeCapability() { + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + return 10 * deviceProp.major + deviceProp.minor; } size_t first_dimension(const std::vector &shape) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index dc515ccb8e..01134a23ca 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -95,21 +95,29 @@ struct TypeInfo{ constexpr static size_t size = sizeof(T); }; +struct QuantizationOptions { + bool force_pow_2_scales = false; + float amax_epsilon = 0.0; + size_t block_scaling_dim = 2u; +}; + class Tensor { public: Tensor(const std::string& name, const NVTEShape &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, + const QuantizationOptions* q_opts = nullptr); Tensor(const std::string& name, const std::vector &shape, const DType type, const bool rowwise = true, const bool columnwise = false, - const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : - Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} + const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING, + const QuantizationOptions* q_opts = nullptr) : + Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode, q_opts) {} Tensor() {} @@ -136,25 +144,19 @@ class Tensor { if (scale_inv != nullptr) { cudaFree(scale_inv); } - if (columnwise_data_ptr != nullptr){ + if (columnwise_data_ptr != nullptr) { cudaFree(columnwise_data_ptr); } - if (columnwise_scale_inv != nullptr){ + if (columnwise_scale_inv != nullptr) { cudaFree(columnwise_scale_inv); } } - NVTETensor data() const noexcept { - return tensor_.data(); - } + NVTETensor data() const noexcept { return tensor_.data(); } - NVTEShape rowwise_shape() const noexcept { - return tensor_.get_rowwise_data().shape; - } + NVTEShape rowwise_shape() const noexcept { return tensor_.get_rowwise_data().shape; } - NVTEShape columnwise_shape() const noexcept { - return tensor_.get_columnwise_data().shape; - } 
+ NVTEShape columnwise_shape() const noexcept { return tensor_.get_columnwise_data().shape; } NVTEShape rowwise_scale_inv_shape() const { NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!"); @@ -221,6 +223,8 @@ class Tensor { T *rowwise_cpu_scale_inv_ptr(){ if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -232,6 +236,8 @@ class Tensor { T *columnwise_cpu_scale_inv_ptr(){ if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){ NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -455,6 +461,7 @@ extern std::vector all_fp_types; bool isFp8Type(DType type); int32_t getDeviceComputeCapability(); +constexpr int32_t hopperComputeCapability = 90; constexpr int32_t blackwellComputeCapability = 100; } // namespace test diff --git a/tests/pytorch/references/blockwise_quantizer_reference.py b/tests/pytorch/references/blockwise_quantizer_reference.py new file mode 100644 index 0000000000..72cb062c31 --- /dev/null +++ b/tests/pytorch/references/blockwise_quantizer_reference.py @@ -0,0 +1,361 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import dataclasses +import math +import torch +from typing import Optional, Protocol, Tuple + + +@dataclasses.dataclass() +class QuantizeResult: + data: torch.Tensor + scale: torch.Tensor + data_t: Optional[torch.Tensor] + scale_t: Optional[torch.Tensor] + + +# FIXME(kwyss): Put this in a common location for per-tensor current +# scaling reference +def _scale_from_amax_tensor( + x_dtype: torch.dtype, + amax: torch.Tensor, + quant_dtype: torch.dtype, + *, + eps: float, + pow_2_scales: bool, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Derives quantization and dequantization from amax and options. + + Reference implementation for scale calculation. + + Returns: + - scale: quantization scales + - scale_inv: dequantization scales + - amax: Amax tensor with updates made for extrema values. + """ + assert amax.dtype == torch.float, "amax must be a float tensor." + fp8_max = torch.finfo(quant_dtype).max + # Clamping amax to avoid division by small numbers + amax = torch.max(amax, torch.tensor(eps)) + + # Compute scale factor + scale = torch.div(fp8_max, amax) + # Note frexp doesn't give back inf for exponent with an inf input + # We take care of inf before pow_2_scales + scale = torch.where(scale == torch.inf, torch.finfo(x_dtype).max, scale) + if pow_2_scales: + # Calculate rounded down exponent + _, exp = torch.frexp(scale) + # Positive numbers are always returned as mant, exp with + # a mantissa in [0.5, 1.0). Because a normal float has a mantissa with + # hidden bit in [1.0, 2.0), the exponent will be off by exactly one because + # of the shift. Subnormal and zero cases need not be considered because + # the smallest possible result of fp8_max / amax is still normal. + exp = exp - 1 + # No subnormals and zero. + assert (exp > -127).all() + unity = torch.tensor([1.0], device=exp.device) + torch.ldexp(unity, exp, out=scale) + # Case where amax is inf. 
The frexp, ldexp logic changes 0.0 scales + # Return 0.0 for 0.0 scale for consistency with non-pow2 scale + # calculation. + scale = torch.where(amax == float("inf"), 0.0, scale) + + # Handle overflow cases for amax zero causing NaN + scale = torch.where(amax == 0, 1.0, scale) + + # Compute scale_inv + scale_inv = torch.reciprocal(scale) + + return scale, scale_inv, amax + + +@dataclasses.dataclass() +class CuBLASScaleMunger: + + def munge_scale_shapes_for_backend( + self, + unmunged: QuantizeResult, + tile_shape: Tuple[int, int], + ) -> QuantizeResult: + """ + cuBLAS GEMMs requires 1x128 quantized tensors to be have scales transposed + so that for an (M, N) tensor, the scales are (RounUpDiv(N, 128), M) + + For 128x128 quantized tensors, the GEMM expects (M, PadToAlign(RoundUpDivide(N, 128), 4)) + format. If RoundUpDivide(N, 128) is not divisible by 4, a transformation is required + """ + if tile_shape[0] != 1: + # 2D block quantized tensor needs padding for cuBLAS GEMM. + def _munge_scale_tensor(s: torch.Tensor) -> torch.Tensor: + M, K = s.shape + if K % 4 == 0: + return s + k_pad = 4 - (K % 4) + return torch.nn.functional.pad(s, (0, k_pad), mode="constant", value=0).contiguous() + + s = _munge_scale_tensor(unmunged.scale) + if unmunged.scale_t is None: + s_t = None + else: + s_t = _munge_scale_tensor(unmunged.scale_t) + return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t) + + # 1D block quantized tensors needs transpose to prepare for the GEMM. + s = unmunged.scale.transpose(-1, -2).contiguous() + if unmunged.scale_t is None: + s_t = None + else: + s_t = unmunged.scale_t.transpose(-1, -2).contiguous() + return QuantizeResult(unmunged.data, s, unmunged.data_t, s_t) + + def demunge_scale_shape_from_backend( + cls, + qtensor_shape: Tuple[int, int], + scales: torch.Tensor, + tile_shape: Tuple[int, int], + ) -> torch.Tensor: + """ + Inverse operation of munge_scale_shapes_for_backend + """ + if tile_shape[0] != 1: + # 2D block quantized tensor may need padding stripped off + derived_scale_k_shape = math.ceil(qtensor_shape[1] / tile_shape[1]) + M, K = scales.shape + if derived_scale_k_shape == K: + return scales + else: + return scales[:, :derived_scale_k_shape].contiguous() + return scales.transpose(-1, -2).contiguous() + + +@dataclasses.dataclass() +class BlockwiseQuantizerReference: + """ + A reference QuantizeOp for subchannel/block hybrid quantization. + + Defers to ref GEMMs and quantizization formatting based on the backend. 
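
    Illustrative usage (not part of the tested API surface; ``x`` here stands
    for any 2-D float32/bfloat16 tensor):

        ref = BlockwiseQuantizerReference()
        result = ref.quantize(
            x,
            quant_dtype=torch.float8_e4m3fn,
            return_transpose=True,
            quant_tile_shape=(128, 128),
        )
        # result.data holds the FP8 values; result.scale holds the float32
        # dequantization scales, already munged into the cuBLAS layout.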
+ """ + + def __init__(self) -> None: + self.scale_munger = CuBLASScaleMunger() + + @classmethod + def _quantize_square_block_tiling( + cls, + x: torch.Tensor, + quant_dtype: torch.dtype, + tile_len: int, + *, + return_transpose: bool, + pow_2_scales: bool, + eps: float, + ) -> QuantizeResult: + M, K = x.shape + + pad_m_k = [0, 0] + if K % tile_len != 0: + pad_m_k[1] = tile_len - (K % tile_len) + if M % tile_len != 0: + pad_m_k[0] = tile_len - (M % tile_len) + + unpadded_m, unpadded_k = M, K + if pad_m_k[0] != 0 or pad_m_k[1] != 0: + x = torch.nn.functional.pad( + x, (0, pad_m_k[1], 0, pad_m_k[0]), mode="constant", value=0 + ).contiguous() + M, K = x.shape + + x_tiled = x.reshape(M // tile_len, tile_len, K // tile_len, tile_len) + amax_grid = ( + torch.abs(x_tiled.transpose(-3, -2)) + .reshape(M // tile_len, K // tile_len, tile_len**2) + .amax(dim=-1) + ).float() + dtype_max = torch.finfo(quant_dtype).max + + scale, scale_inv, _ = _scale_from_amax_tensor( + x_dtype=x.dtype, + amax=amax_grid, + quant_dtype=quant_dtype, + pow_2_scales=pow_2_scales, + eps=eps, + ) + qx = x_tiled * scale.reshape(M // tile_len, 1, K // tile_len, 1) + qx = torch.clamp(qx, min=-dtype_max, max=dtype_max) + qx = qx.to(dtype=quant_dtype) + qx = qx.reshape(M, K) + if unpadded_k != K or unpadded_m != M: + qx = qx[:unpadded_m, :unpadded_k].contiguous() + if return_transpose: + # Valid because of square block sizes + qx_t = qx.transpose(-1, -2).contiguous() + scale_inv_t = scale_inv.transpose(-1, -2).contiguous() + else: + qx_t = None + scale_inv_t = None + + return QuantizeResult(data=qx, scale=scale_inv, data_t=qx_t, scale_t=scale_inv_t) + + @classmethod + def _quantize_vectorwise_reference( + cls, + x: torch.Tensor, + quant_dtype: torch.dtype, + tile_len: int, + *, + pow_2_scales: bool, + eps: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + M, K = x.shape + dtype_max = torch.finfo(quant_dtype).max + x_tiled = x.reshape(M, K // tile_len, tile_len) + amax_grid = torch.abs(x_tiled).amax(dim=-1).float() + scale, scale_inv, _ = _scale_from_amax_tensor( + x_dtype=x.dtype, + amax=amax_grid, + quant_dtype=quant_dtype, + pow_2_scales=pow_2_scales, + eps=eps, + ) + qx = x_tiled * scale.reshape(M, K // tile_len, 1) + qx = torch.clamp(qx, min=-dtype_max, max=dtype_max) + qx = qx.to(dtype=quant_dtype) + qx = qx.reshape(M, K) + return qx, scale_inv + + @classmethod + def _quantize_vector_tiling( + cls, + x: torch.Tensor, + quant_dtype: torch.dtype, + tile_len: int, + *, + return_transpose: bool, + pow_2_scales: bool, + eps: float, + ) -> QuantizeResult: + M, K = x.shape + + if K % tile_len == 0: + qref_input = x + else: + pad_amount = tile_len - (K % tile_len) + pad = (0, pad_amount) + qref_input = torch.nn.functional.pad(x, pad, mode="constant", value=0) + qout_padded, scale_inv = cls._quantize_vectorwise_reference( + qref_input, + quant_dtype, + tile_len=tile_len, + pow_2_scales=pow_2_scales, + eps=eps, + ) + if K % tile_len == 0: + qout = qout_padded + else: + qout = qout_padded[:, :K].contiguous() + + if return_transpose: + if M % tile_len == 0: + qref_input = x.transpose(-1, -2).contiguous() + else: + amount_to_pad = tile_len - (M % tile_len) + pad = (0, amount_to_pad) + qref_input = torch.nn.functional.pad( + x.transpose(-1, -2), pad, mode="constant", value=0 + ).contiguous() + qout_t_padded, scale_inv_t = cls._quantize_vectorwise_reference( + qref_input, + quant_dtype, + tile_len=tile_len, + pow_2_scales=pow_2_scales, + eps=eps, + ) + if M % tile_len == 0: + qout_t = qout_t_padded + else: + qout_t = qout_t_padded[:, 
:M].contiguous() + else: + qout_t, scale_inv_t = None, None + + return QuantizeResult(data=qout, scale=scale_inv, data_t=qout_t, scale_t=scale_inv_t) + + def ref_dequantize_rowwise( + self, + q: torch.Tensor, + quant_tile_shape: Tuple[int, int], + s: torch.Tensor, + dtype: torch.dtype, + ) -> torch.Tensor: + assert q.dim() == 2 + q_M, q_K = q.shape + s = self.scale_munger.demunge_scale_shape_from_backend((q_M, q_K), s, quant_tile_shape) + assert len(s.shape) == 2 + m_tiles, k_tiles = s.shape + M, K = q.shape + unpadded_m, unpadded_k = M, K + if M % quant_tile_shape[0] != 0 or K % quant_tile_shape[1] != 0: + m_pad_amount = (quant_tile_shape[0] - (M % quant_tile_shape[0])) % quant_tile_shape[0] + k_pad_amount = (quant_tile_shape[1] - (K % quant_tile_shape[1])) % quant_tile_shape[1] + q = torch.nn.functional.pad( + q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0 + ).contiguous() + M, K = q.shape + q_tiled = q.reshape(m_tiles, quant_tile_shape[0], k_tiles, quant_tile_shape[1]) + result = q_tiled.to(dtype) * s.reshape(m_tiles, 1, k_tiles, 1) + result = result.view(M, K).to(dtype) + if M != unpadded_m or K != unpadded_k: + result = result[:unpadded_m, :unpadded_k].contiguous() + return result + + def quantize( + self, + x: torch.Tensor, + quant_dtype: torch.dtype, + return_transpose: bool = False, + eps: float = 0.0, + pow_2_scales: bool = False, + quant_tile_shape: Tuple[int, int] = (128, 128), + ) -> QuantizeResult: + # sanity checks + assert x.dim() == 2 + assert x.dtype in ( + torch.float, + torch.float16, + torch.bfloat16, + torch.float32, + ), "Unsupported input dtype." + assert quant_dtype in ( + torch.float8_e4m3fn, + torch.float8_e5m2, + ), "Unsupported quant dtype." + + assert quant_tile_shape in ((1, 128), (128, 128)) + if quant_tile_shape[0] == 1: + # Quantize row-wise + return self.scale_munger.munge_scale_shapes_for_backend( + self._quantize_vector_tiling( + x, + quant_dtype, + tile_len=quant_tile_shape[1], + return_transpose=return_transpose, + pow_2_scales=pow_2_scales, + eps=eps, + ), + quant_tile_shape, + ) + else: + # Quantize block-wise + return self.scale_munger.munge_scale_shapes_for_backend( + self._quantize_square_block_tiling( + x, + quant_dtype, + tile_len=quant_tile_shape[0], + return_transpose=return_transpose, + pow_2_scales=pow_2_scales, + eps=eps, + ), + quant_tile_shape, + ) diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py new file mode 100644 index 0000000000..16647184d6 --- /dev/null +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -0,0 +1,291 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from typing import Tuple +import math +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex + +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + Float8BlockwiseQTensor, +) +from tests.pytorch.references.blockwise_quantizer_reference import ( + BlockwiseQuantizerReference, + QuantizeResult, +) + + +def initialize_for_many_scales( + x_shape_2d: Tuple[int, int], tile_shape: Tuple[int, int], *, dtype: torch.dtype, device: str +) -> torch.Tensor: + """ + Put separate distributions into each quantization tile + to avoid many tiles having similar scale values and + causing false passes. 
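
    Each tile is filled with a ramp from 0 toward a per-tile extremum drawn
    uniformly from [-max_val, max_val], so neighbouring tiles end up with
    distinct amax values.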
+ """ + tile_grid_shape = ( + math.ceil(x_shape_2d[0] / tile_shape[0]), + math.ceil(x_shape_2d[1] / tile_shape[1]), + ) + # Arbitrary size + max_val = 8192.0 + # Make a uniform distribution of [-max_val, max_val] + tile_extrema = torch.rand(*tile_grid_shape, dtype=dtype) * max_val * 2 - max_val + result = torch.empty(x_shape_2d, dtype=dtype, device=device) + tile_elements = tile_shape[0] * tile_shape[1] + for i in range(tile_grid_shape[0]): + for j in range(tile_grid_shape[1]): + target = tile_extrema[i, j].item() + step = target / (tile_elements) + if target == 0: + tile = torch.zeros(tile_shape, dtype=dtype, device=device) + else: + tile = torch.arange(0.0, target, step=step, dtype=dtype, device=device) + tile = tile.reshape(*tile_shape) + min_dst_vals = (i * tile_shape[0], j * tile_shape[1]) + max_dst_vals = ( + min((i + 1) * tile_shape[0], x_shape_2d[0]), + min((j + 1) * tile_shape[1], x_shape_2d[1]), + ) + max_src_vals = ( + max_dst_vals[0] - min_dst_vals[0], + max_dst_vals[1] - min_dst_vals[1], + ) + result[min_dst_vals[0] : max_dst_vals[0], min_dst_vals[1] : max_dst_vals[1]] = tile[ + : max_src_vals[0], : max_src_vals[1] + ] + return result + + +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (300, 300), + (305, 256), + # Some larger tiles. + (2000, 2000), + (2048, 2000), + (2000, 1024), + (2048, 1024), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0, 1e-12], ids=["eps_0", "eps_1e-12"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "quantize_only"] +) +@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "f32scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128), (128, 128)], ids=["1DTile", "2DTile"]) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + return_transpose: bool, + pow_2_scales: bool, + tile_size: Tuple[int, int], +) -> None: + te_dtype = TE_DType[quant_dtype] + if tile_size == (1, 128): + block_scaling_dim = 1 + elif tile_size == (128, 128): + block_scaling_dim = 2 + else: + raise ValueError("Non support tile size") + # This test runs a comparison of the ref class versus the class using + # CUDA kernels to quantize. They should quantize identically for pixels + # that are not DC values in the scale factor shape. 
+ ref_quantizer = BlockwiseQuantizerReference() + sut_quantizer = Float8BlockQuantizer( + fp8_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + amax_epsilon=eps, + force_pow_2_scales=pow_2_scales, + block_scaling_dim=block_scaling_dim, + ) + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = initialize_for_many_scales((M, N), tile_size, dtype=x_dtype, device=device) + + x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False) + x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut) + + assert x_fp8_sut._rowwise_data is not None + qx: torch.Tensor = x_fp8_sut._rowwise_data.view(dtype=quant_dtype) + assert x_fp8_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_fp8_sut._rowwise_scale_inv + qx_t = x_fp8_sut._columnwise_data + sx_t = x_fp8_sut._columnwise_scale_inv + + qresult_ref = ref_quantizer.quantize( + x, + quant_dtype=quant_dtype, + return_transpose=return_transpose, + eps=eps, + pow_2_scales=pow_2_scales, + quant_tile_shape=tile_size, + ) + qx_ref, sx_ref, qx_t_ref, sx_t_ref = ( + qresult_ref.data, + qresult_ref.scale, + qresult_ref.data_t, + qresult_ref.scale_t, + ) + + # Check + torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0) + if tile_size[0] != 1: + # Zero out values that are don't care values + # cuBLAS has padding of 2D tensors. + scale_mask = torch.ones( + (math.ceil(M / tile_size[0]), math.ceil(N / tile_size[1])), device=sx.device + ) + scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend( + QuantizeResult(qx, scale_mask, None, None), tile_size + ).scale + sx = sx * scale_mask + + torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + assert qx_t is not None + qx_t = qx_t.view(dtype=quant_dtype) + assert qx_t_ref is not None + assert sx_t is not None + assert sx_t_ref is not None + if tile_size[0] != 1: + scale_mask = torch.ones( + (math.ceil(N / tile_size[0]), math.ceil(M / tile_size[1])), + device=sx_t.device, + ) + scale_mask = ref_quantizer.scale_munger.munge_scale_shapes_for_backend( + QuantizeResult(qx_t, scale_mask, None, None), tile_size + ).scale + sx_t = sx_t * scale_mask + torch.testing.assert_close(qx_t.float(), qx_t_ref.float(), atol=0.0, rtol=0.0) + torch.testing.assert_close(sx_t, sx_t_ref, atol=0.0, rtol=0.0) + else: + # should be None + assert qx_t is None and qx_t_ref is None + assert sx_t is None and sx_t_ref is None + + +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (1, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("quant_dtype", [torch.float8_e4m3fn, torch.float8_e5m2], ids=str) +@pytest.mark.parametrize("eps", [0, math.pow(2, -125)], ids=["eps_0", "eps_small"]) +@pytest.mark.parametrize("pow_2_scales", [True, False], ids=["pow2scales", "f32scales"]) +@pytest.mark.parametrize("tile_size", [(1, 128)]) +@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) +def test_quantization_block_tiling_extrema_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + quant_dtype: torch.dtype, + eps: float, + pow_2_scales: bool, + tile_size: Tuple[int, int], + extrema_high: bool, +) -> None: + # This test runs a single tile through a quantizer as a way to test + # branch coverage of scale computation. 
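    # Covered branches: all-zero input (amax == 0, scale forced to 1.0), input at
    # the dtype max (scale optionally rounded down to a power of two), and a tiny
    # eps making fp8_max / eps overflow to inf, where the scale is clamped to the
    # input dtype's max (scale_inv = 1 / dtype_max, or 2**-127 under pow-2 scaling).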
+ te_dtype = TE_DType[quant_dtype] + if tile_size == (1, 128): + block_scaling_dim = 1 + elif tile_size == (128, 128): + block_scaling_dim = 2 + else: + raise ValueError("Non support tile size") + ref_quantizer = BlockwiseQuantizerReference() + sut_quantizer = Float8BlockQuantizer( + fp8_dtype=te_dtype, + rowwise=True, + columnwise=False, + amax_epsilon=eps, + force_pow_2_scales=pow_2_scales, + block_scaling_dim=block_scaling_dim, + ) + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + return_transpose = False + # Input + if extrema_high: + x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device) + else: + x = torch.zeros((M, N), dtype=x_dtype, device=device) + + # Run cast and transpose kernel + # Internal call ops.quantize_tensorwise + x_fp8_sut = sut_quantizer.make_empty((M, N), dtype=x_dtype, device=device, requires_grad=False) + x_fp8_sut = sut_quantizer.update_quantized(x, x_fp8_sut) + qx = x_fp8_sut._rowwise_data.view(dtype=quant_dtype) + sx = x_fp8_sut._rowwise_scale_inv + + qresult_ref = ref_quantizer.quantize( + x, + quant_dtype=quant_dtype, + return_transpose=return_transpose, + eps=eps, + pow_2_scales=pow_2_scales, + quant_tile_shape=tile_size, + ) + qx_ref, sx_ref = ( + qresult_ref.data, + qresult_ref.scale, + ) + + # Check + torch.testing.assert_close(qx.float(), qx_ref.float(), atol=0.0, rtol=0.0) + torch.testing.assert_close(sx, sx_ref, atol=0.0, rtol=0.0) + + if extrema_high: + expected_value = torch.finfo(quant_dtype).max / torch.finfo(x_dtype).max + if pow_2_scales: + expected_value = math.floor(math.log2(expected_value)) + expected_value = math.pow(2.0, expected_value) + expected_value = 1 / expected_value + elif not extrema_high and eps == 0: + expected_value = 1.0 + else: + assert not extrema_high + # eps is small enough to trigger inf in quant_dtype_max / eps + if pow_2_scales: + expected_value = math.pow(2.0, -127) + else: + expected_value = 1 / torch.finfo(x_dtype).max + torch.testing.assert_close( + sx, + torch.tensor([expected_value], device=sx.device).reshape(1, 1), + atol=0.0, + rtol=0.0, + ) diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py new file mode 100644 index 0000000000..7058fdb22f --- /dev/null +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -0,0 +1,201 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +from collections.abc import Iterable +import io +import math +from typing import Any, Dict, List, Tuple, Union + +import pytest +import torch + +import transformer_engine.common.recipe +import transformer_engine.pytorch as te +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + Float8BlockwiseQTensor, +) +import transformer_engine_torch as tex + +# PyTorch tensor dtypes +_dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] +# TE FP8 dtypes +_fp8_dtypes: List[tex.DType] = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] + +# Numerical tolerances with FP8 types +_tols: Dict[tex.DType, Dict[str, float]] = { + tex.DType.kFloat8E4M3: dict(rtol=0.125, atol=0.0675), # epsilon = 0.0625 + tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125 +} + + +def _to_list(x: Union[Iterable, Any]) -> List: + """Convert to list if iterable, otherwise put in singleton list""" + if isinstance(x, Iterable): + return list(x) + else: + return [x] + + +# Types that can be interpreted as tensor dims +DimsType = Union[Iterable[int], int] + +# Check if FP8 is supported +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +class TestFloat8BlockwiseTensor: + + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + def test_constructor( + self, + dims: DimsType = 1, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + dtype: torch.dtype = torch.float32, + ) -> None: + """Call constructor and perform sanity checks""" + dims = _to_list(dims) + + rowwise = True + columnwise = True + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, rowwise=rowwise, columnwise=columnwise + ) + + scale_dims = quantizer.get_scale_shape(dims, columnwise=False) + columnwise_scale_dims = quantizer.get_scale_shape(dims, columnwise=True) + columnwise_dims = quantizer.get_columnwise_shape(dims) + tensor = Float8BlockwiseQTensor( + shape=dims, + dtype=dtype, + rowwise_data=torch.zeros(dims, device="cuda", dtype=torch.uint8), + rowwise_scale_inv=torch.zeros(scale_dims, device="cuda", dtype=torch.float32), + columnwise_data=torch.zeros(columnwise_dims, device="cuda", dtype=torch.uint8), + columnwise_scale_inv=torch.zeros( + columnwise_scale_dims, device="cuda", dtype=torch.float32 + ), + fp8_dtype=fp8_dtype, + quantizer=quantizer, + ) + assert list(tensor.size()) == dims, "Incorrect dims" + assert tensor.dtype == dtype, "Incorrect nominal dtype" + assert tensor.is_cuda, "Incorrect device" + + def _test_quantize_dequantize( + self, + quantizer: Float8BlockQuantizer, + dtype: torch.dtype = torch.float32, + dims: DimsType = (23, 128), + rtol: float = 0.0, + atol: float = 0.0, + dequant_columnwise: bool = False, + use_cpp_allocation: bool = False, + ) -> None: + """Check numerical error when casting to FP8 and back""" + dims = _to_list(dims) + + # Initialize random data + x_ref = 2 * torch.rand(dims, dtype=dtype, device="cpu") - 1 + x_ref_cuda = x_ref.to("cuda") + + # Cast to FP8 and back + if not use_cpp_allocation: + x_fp8 = quantizer.make_empty(shape=dims, device="cuda") + quantizer.update_quantized(x_ref_cuda, x_fp8) + else: + # This codepath allows the CPP binding to allocate the output + # tensor + x_fp8 = tex.quantize(x_ref_cuda, quantizer, None, None) + if dequant_columnwise: + # Strip out rowwise data to verify 
dequantization of + # columnwise data. + x_fp8.update_usage(rowwise_usage=False, columnwise_usage=True) + x_fp8 = x_fp8.dequantize(dtype=dtype).cpu() + + # Check results + torch.testing.assert_close(x_fp8, x_ref, rtol=rtol, atol=atol) + + # Make sure we are not trivially passing the test + with pytest.raises(AssertionError): + torch.testing.assert_close(x_fp8, -x_ref, rtol=rtol, atol=atol) + + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + def test_quantize_dequantize_dtypes( + self, fp8_dtype: tex.DType, dtype: torch.dtype, block_scaling_dim: int + ) -> None: + atol = _tols[fp8_dtype]["atol"] + rtol = _tols[fp8_dtype]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=False, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize(quantizer=quantizer, dtype=dtype, atol=atol, rtol=rtol) + + @pytest.mark.parametrize( + "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] + ) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + @pytest.mark.parametrize("dq_columnwise", [True, False]) + def test_quantize_dequantize_dims( + self, dims: DimsType, block_scaling_dim: int, dq_columnwise: bool + ) -> None: + atol = _tols[tex.DType.kFloat8E4M3]["atol"] + rtol = _tols[tex.DType.kFloat8E4M3]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=dq_columnwise, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize( + quantizer=quantizer, + dims=dims, + atol=atol, + rtol=rtol, + dequant_columnwise=dq_columnwise, + ) + + @pytest.mark.parametrize( + "dims", [[], 256, 311, [264], [256, 512], [250, 500], [7, 5, 3], [2, 3, 5, 3]] + ) + @pytest.mark.parametrize("block_scaling_dim", [1, 2]) + @pytest.mark.parametrize("fp8_dtype", _fp8_dtypes) + @pytest.mark.parametrize("dq_columnwise", [True, False]) + def test_quantize_dequantize_dims_cpp_allocate_output( + self, dims: DimsType, block_scaling_dim: int, fp8_dtype: tex.DType, dq_columnwise: bool + ) -> None: + atol = _tols[fp8_dtype]["atol"] + rtol = _tols[fp8_dtype]["rtol"] + quantizer = Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=dq_columnwise, + block_scaling_dim=block_scaling_dim, + ) + self._test_quantize_dequantize( + quantizer=quantizer, + dims=dims, + atol=atol, + rtol=rtol, + dequant_columnwise=dq_columnwise, + use_cpp_allocation=True, + ) + + # FIXME(kwyss): Add some testing for other tensor operations. 
+ # - basic_ops + # - in_place_ops + # - serialization + # - set_data diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index c77d230ce5..17421d039e 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -58,6 +58,8 @@ list(APPEND transformer_engine_SOURCES transpose/cast_transpose_fusion.cu transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu + transpose/quantize_transpose_square_blockwise.cu + transpose/quantize_transpose_vector_blockwise.cu activation/gelu.cu fused_attn/fused_attn_f16_max512_seqlen.cu fused_attn/fused_attn_f16_arbitrary_seqlen.cu diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 46eb248156..1f19b14174 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -86,6 +86,10 @@ struct Tensor { NVTEScalingMode scaling_mode; + float amax_epsilon; + bool force_pow_2_scales; + int block_scaling_dim; + Tensor() : data(), columnwise_data(), @@ -93,7 +97,10 @@ struct Tensor { scale(nullptr, {1}, DType::kFloat32), scale_inv(nullptr, {1}, DType::kFloat32), columnwise_scale_inv(nullptr, {1}, DType::kFloat32), - scaling_mode(NVTE_DELAYED_TENSOR_SCALING) {} + scaling_mode(NVTE_DELAYED_TENSOR_SCALING), + amax_epsilon(0.0), + force_pow_2_scales(false), + block_scaling_dim(scaling_mode == NVTE_BLOCK_SCALING ? 2 : 0) {} int numel() const { NVTE_CHECK(data.dptr != nullptr || columnwise_data.dptr != nullptr, @@ -116,6 +123,33 @@ struct Tensor { bool has_columnwise_data() const noexcept { return columnwise_data.dptr != nullptr; } + bool supports_force_pow_2_scales_qopt() const noexcept { + switch (scaling_mode) { + case NVTE_BLOCK_SCALING: + return true; + default: + return false; + } + } + + bool supports_amax_epsilon_qopt() const noexcept { + switch (scaling_mode) { + case NVTE_BLOCK_SCALING: + return true; + default: + return false; + } + } + + bool supports_block_scaling_dim(int block_scaling_dim) const noexcept { + switch (scaling_mode) { + case NVTE_BLOCK_SCALING: + return block_scaling_dim == 1 || block_scaling_dim == 2; + default: + return false; + } + } + DType dtype() const { if (has_data()) return data.dtype; if (has_columnwise_data()) return columnwise_data.dtype; @@ -396,6 +430,15 @@ struct TypeInfo { } \ } +#define TRANSFORMER_ENGINE_SWITCH_CONDITION(CONDITION, FLAG, ...) \ + if (CONDITION) { \ + constexpr bool FLAG = true; \ + { __VA_ARGS__ } \ + } else { \ + constexpr bool FLAG = false; \ + { __VA_ARGS__ } \ + } + //////////////////////////////////////////////////////////////////////////////////////////////////// inline int log2_ceil(int value) { diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 52fa89b914..cb1cc0e757 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -76,6 +76,9 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla const transformer_engine::Tensor &B, const cublasOperation_t transB, const int k, const int lda, const int ldb) { using namespace transformer_engine; + // FIXME(kwyss): 1x128 by 128x128 GEMM is part of the subchannel design. + // Must either force them both into a common block scaling mode or loosen this + // restriction. 
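+  // Note (illustrative): a 1x128-scaled tensor (block_scaling_dim == 1) and a
+  // 128x128-scaled tensor (block_scaling_dim == 2) both report NVTE_BLOCK_SCALING,
+  // so a mixed 1x128-by-128x128 GEMM passes this check even though its scale
+  // layouts do not yet have a dedicated code path below.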
NVTE_CHECK(A.scaling_mode == B.scaling_mode, "Inputs A and B to GEMM need to have the same scaling mode!"); NVTE_CHECK(A.has_data() || A.has_columnwise_data(), "Input A does not hold any data!"); @@ -85,6 +88,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.lda = lda; ret.ldb = ldb; + // FIXME(kwyss): 128x128 by 128x128 GEMMs and 1x128 by 128x128 GEMMs need cases + // or need to be treated as `is_tensor_scaling`. if (is_tensor_scaling(A.scaling_mode)) { ret.A = A.data.dptr; ret.A_scale_inv = A.scale_inv.dptr; @@ -238,6 +243,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); + // FIXME(kwyss): Add binding code for 128x128 block quantized 1x128 block quantized + // GEMM types. + // Scaling factors. #if CUDA_VERSION >= 12080 cublasLtMatmulMatrixScale_t scaling_mode; diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index d57975b2f4..9be0e14d8a 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -42,23 +42,25 @@ extern "C" { * of the output tensor should be set to 0. */ -/*! \brief Casts input tensor to FP8/MXFP8. +/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, - * the block quantization (MXFP8) of the specified shape of the block will be used. + * the MXFP8 block quantization of the specified shape of the block will be used. + * If the scaling mode of the output tensor is set to NVTE_BLOCK_SCALING, + * blockwise float8 scaling will be used. * * \param[in] input Input tensor to be cast. - * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[in,out] output Output FP8/MXFP8/BlockwiseFP8 tensor. * \param[in] stream CUDA stream used for the operation. */ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Casts input tensor to FP8/MXFP8, providing the option to immediately exit the kernel * based on the value of the 'noop' tensor. - * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, - * the block quantization (MXFP8) of the specified shape of the block will be used. + * The type of quantized tensor in the output depends on the scaling mode of the output + * tensor. * * \param[in] input Input tensor to be cast. - * \param[in,out] output Output FP8/MXFP8 tensor. + * \param[in,out] output Output quantized tensor. * \param[out] noop Noop tensor. * \param[in] stream CUDA stream used for the operation. */ diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index e393dbffc4..ac0a4a0a77 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -77,6 +77,11 @@ enum NVTEScalingMode { /*! Single scale per block of 32 elements consecutive in either rowwise or columnwise direction */ NVTE_MXFP8_1D_SCALING = 1, + /*! Tensor is split into NxN quantization tiles or 1xN quantization tiles, + which each yield a scale. The block_scaling_dim property of the quantizer + selects the granularity. 
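+      For example, with N = 128, block_scaling_dim = 1 yields one scale per 128
+      consecutive elements in a row (1x128 tiles), while block_scaling_dim = 2
+      yields one scale per 128x128 tile.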
+ */ + NVTE_BLOCK_SCALING = 2, NVTE_INVALID_SCALING }; @@ -231,6 +236,63 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream); void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, const NVTEBasicTensor *param); +/*! \brief Set a quantization option for whether to force power of 2 scales. + * + * \param[in/out] tensor Tensor. + * \param[in] zero_if_false Whether to force power of 2 scales. + * + * \return zero if the tensor supports this option and it was set. non-zero if + * call had no effect. + */ +int nvte_set_qopt_force_pow_2_scales(NVTETensor tensor, int zero_if_false); + +/*! \brief Set a quantization option for epsilon to set floor of amax. + * + * \param[in/out] tensor Tensor. + * \param[in] amax_epsilon Epsilon to use for amax calculation. + * + * \return zero if the tensor supports this option and it was set. non-zero if + * call had no effect. + */ +int nvte_set_qopt_amax_epsilon(NVTETensor tensor, float amax_epsilon); + +/*! \brief Set a quantization option to use 1D or 2D quantization blocks + * to scale the tensor. + * + * \param[in/out] tensor Tensor. + * \param[in] block_scaling_dim, 1D or 2D. + * + * \return zero if the tensor supports this option and it was set. non-zero if + * call had no effect or the number of dims is not supported. + */ +int nvte_set_qopt_block_scaling_dim(NVTETensor tensor, int block_scaling_dim); + +/*! \brief Get a quantization option for whether to force power of 2 scales. + * + * \param[in] tensor Tensor. + * + * \return zero if the tensor will not force power of 2 scales or if the + * setting is irrelevant. non-zero if the flag is configured. + */ +int nvte_get_qopt_force_pow_2_scales(NVTETensor tensor); + +/*! \brief Get a quantization option for amax epsilon. + * + * \param[in] tensor Tensor. + * + * \return amax_epsilon value or zero if not applicable. + */ +float nvte_get_qopt_amax_epsilon(const NVTETensor tensor); + +/*! \brief Get the number of dimensions in the quantization blocks. + * + * \param[in] tensor Tensor. + * + * \return zero if the quantization does not support the block_scaling_dim + * option or the block_scaling_dim configured. + */ +int nvte_get_qopt_block_scaling_dim(const NVTETensor tensor); + /*! \brief Get a value of the parameter of the tensor. * * \param[in] tensor Tensor. @@ -598,6 +660,24 @@ class TensorWrapper { void zero_(cudaStream_t stream) { nvte_zero_tensor(tensor_, stream); } + int set_qopt_force_pow_2_scales(bool flag) { + return nvte_set_qopt_force_pow_2_scales(tensor_, flag ? 
1 : 0); + } + + int set_qopt_amax_epsilon(float eps) { return nvte_set_qopt_amax_epsilon(tensor_, eps); } + + int set_qopt_block_scaling_dim(int block_scaling_dim) { + return nvte_set_qopt_block_scaling_dim(tensor_, block_scaling_dim); + } + + bool get_qopt_force_pow_2_scales() const { + return nvte_get_qopt_force_pow_2_scales(tensor_) != 0; + } + + float get_qopt_amax_epsilon() const { return nvte_get_qopt_amax_epsilon(tensor_); } + + int get_qopt_block_scaling_dim() const { return nvte_get_qopt_block_scaling_dim(tensor_); } + static constexpr size_t defaultData = 1; static constexpr NVTEShape defaultShape = {&defaultData, 1}; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index faf6ec990d..9e9174237a 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -412,3 +412,48 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { } cudaStreamSynchronize(stream); } + +int nvte_set_qopt_force_pow_2_scales(NVTETensor tensor, int zero_if_false) { + auto &t = *reinterpret_cast(tensor); + if (t.supports_force_pow_2_scales_qopt()) { + t.force_pow_2_scales = zero_if_false != 0; + return 0; + } else { + return 1; + } +} + +int nvte_set_qopt_amax_epsilon(NVTETensor tensor, float amax_epsilon) { + auto &t = *reinterpret_cast(tensor); + if (t.supports_amax_epsilon_qopt()) { + t.amax_epsilon = amax_epsilon; + return 0; + } else { + return 1; + } +} + +int nvte_set_qopt_block_scaling_dim(NVTETensor tensor, int block_scaling_dim) { + auto &t = *reinterpret_cast(tensor); + if (t.supports_block_scaling_dim(block_scaling_dim)) { + t.block_scaling_dim = block_scaling_dim; + return 0; + } else { + return 1; + } +} + +int nvte_get_qopt_force_pow_2_scales(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return t.force_pow_2_scales ? 1 : 0; +} + +float nvte_get_qopt_amax_epsilon(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return t.amax_epsilon; +} + +int nvte_get_qopt_block_scaling_dim(const NVTETensor tensor) { + const auto &t = *reinterpret_cast(tensor); + return t.block_scaling_dim; +} diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index ed9bd5f5f7..cf2cb15174 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -23,6 +23,18 @@ template +#include +#include + +#include +#include + +namespace transformer_engine { + +// Type trait for extreme values of fp8 types. +// Used in the calculation of scale factors +// as a constexpr lookup from e4m3 or e5m2 to +// the max finite value. +template +struct F8LimitsTrait; + +template <> +struct F8LimitsTrait<__nv_fp8_e4m3> { + static constexpr float max = 448.0f; +}; + +template <> +struct F8LimitsTrait<__nv_fp8_e5m2> { + static constexpr float max = 57344.0f; +}; + +// Type trait to resolve the max finite value +// represented by a input type to quantization. +// Or to represent max representable power of 2 +// finite value. 
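+// The power-of-2 specializations below (2^127 for float and bf16, 2^15 for fp16)
+// ensure that a saturated scale is itself a power of two when power-of-2 scales
+// are requested; the other specializations use the type's full finite maximum.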
+template +struct HighPrecisionFloatScaleLimitsTrait; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + static constexpr float max = std::numeric_limits::max(); +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.0 * 2 ^ 127 + static constexpr float max = 0x1.0p127; +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.(7 bits of 1) * 2 ^ 127 + static constexpr float max = 0x1.FEp127; +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.0 * 2 ^ 127 + static constexpr float max = 0x1.0p127; +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.(10 bits of 1) * 2 ^ 15 + static constexpr float max = 0x1.FFCp15; +}; + +template <> +struct HighPrecisionFloatScaleLimitsTrait { + // Hex float format of 1.0 * 2 ^ 15 + static constexpr float max = 0x1.0p15; +}; + +// Calculate the quantization scale for an individual data element +// given the amax(abs(tile)) value for a given quantization tile. +// +// +// Arguments: +// IType: data type of the tensor being quantized (float or bf16) +// OType: quantized data type (e4m3 or e5m2) +// pow_2_scaling: Whether to force the scale to be a power of 2. +// amax: The evaluation of amax(abs(tile)) for the quantization tile. +// eps: An epsilon used as a floor for amax. +template +__device__ __forceinline__ float ComputeScale(const float amax, const float eps) { + constexpr float fp8_max = F8LimitsTrait::max; + + // Clamping amax to avoid division by small numbers + float amax_mod = fmaxf(amax, eps); + + // Handle overflow cases for non-clamped amax (eps is 0 or very small) + if (amax_mod == 0.f) { + // If amax is 0, return 1 + return 1.f; + } + // Compute scale factor + float scale = fp8_max / amax_mod; + + if (isinf(scale)) { + // If scale is infinity, return max value of IType + return HighPrecisionFloatScaleLimitsTrait::max; + } + if (scale == 0.0) { + // Case that amax is "inf". The frexp, ldexp logic changes 0.0 scales. + // Return 0.0 for 0.0 scale here is consistent with non-Power2Scaling model. + // quantization will remove signal from the tensor, + // this is bad for the model, but define pow2Scale behavior + // as returning 0.0 scale. amax calculation can + // improve the situation to avoid this by taking largest finite. + return scale; + } + if constexpr (Power2Scaling) { + // NOTE: using bit fiddling based on advice of Asit in this + // thread: https://nvidia.slack.com/archives/C06EDT7LZEW/p1738274404153439 + + // inf scales already early returned, as did nan scales. + // The cases to consider here are normals, zero, and subnormals. + // zero is not possible with current math as + // 448.0 / float_max == 1.31655e-36, which is the smallest + // possible scale given current dtypes. It is still in the normal + // fp32 range with an exponent of -120, so subnormals are also + // not possible. To handle normals, we can simply mask off the + // mantissa. + uint32_t scale_bits = *reinterpret_cast(&scale); + scale_bits &= 0xFF800000; + // If the exponent was zero, we have a logic error. 
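+    // Worked example (illustrative): the 0xFF800000 mask keeps only the sign and
+    // exponent bits, so a scale of 3.0f (bits 0x40400000) becomes 2.0f (bits
+    // 0x40000000), i.e. the scale is rounded down to the nearest power of two.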
+ __builtin_assume(scale_bits != 0); + __builtin_assume(scale_bits != 0x80000000); + scale = *reinterpret_cast(&scale_bits); + } + return scale; +} + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMPUTE_SCALE_CUH_ diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu new file mode 100644 index 0000000000..934cf8a5fb --- /dev/null +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -0,0 +1,603 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include + +#include "common/common.h" +#include "common/utils.cuh" +#include "compute_scale.cuh" + +#if (!defined(__CUDA_MINIMUM_ARCH__)) || \ + (defined(__CUDA_MINIMUM_ARCH__) && __CUDA_MINIMUM_ARCH__ >= 900) +#define TMA_HW_SUPPORTED +#endif + +namespace transformer_engine { +namespace { + +#ifdef TMA_HW_SUPPORTED +using barrier = cuda::barrier; +namespace cde = cuda::device::experimental; +#endif + +// const values configuration + +constexpr size_t kThreadsPerWarp = 32; +#ifdef TMA_HW_SUPPORTED +constexpr size_t BLOCK_TILE_DIM = 128; +constexpr size_t WARP_TILE_DIM_X = 32; +constexpr size_t WARP_TILE_DIM_Y = 64; +constexpr size_t THREAD_TILE_DIM_X = 16; +constexpr size_t THREAD_TILE_DIM_Y = 4; +#else +constexpr size_t BLOCK_TILE_DIM = 128; +constexpr size_t WARP_TILE_DIM_X = 64; +constexpr size_t WARP_TILE_DIM_Y = 32; +constexpr size_t THREAD_TILE_DIM_X = 8; +constexpr size_t THREAD_TILE_DIM_Y = 8; +#endif + +#ifdef TMA_HW_SUPPORTED +constexpr size_t NUM_BYTES_PER_BANK = 4; +constexpr size_t NUM_BANKS_PER_SHARED_ELEM = THREAD_TILE_DIM_Y / NUM_BYTES_PER_BANK; +constexpr size_t SHARED_BLOCK_TILE_DIM_Y = BLOCK_TILE_DIM; +constexpr size_t SHARED_BLOCK_TILE_DIM_X_BANKS = + BLOCK_TILE_DIM / (NUM_BYTES_PER_BANK * NUM_BANKS_PER_SHARED_ELEM); +constexpr size_t NUM_BANKS_Y_IN_WARP = WARP_TILE_DIM_Y / NUM_BYTES_PER_BANK; +#endif +constexpr size_t ELE_PER_THREAD = THREAD_TILE_DIM_X * THREAD_TILE_DIM_Y; +constexpr size_t THREADS_PER_BLOCK = BLOCK_TILE_DIM * BLOCK_TILE_DIM / ELE_PER_THREAD; +constexpr size_t NUM_WARPS_X_IN_BLOCK = BLOCK_TILE_DIM / WARP_TILE_DIM_X; +constexpr size_t NUM_WARPS_Y_IN_BLOCK = BLOCK_TILE_DIM / WARP_TILE_DIM_Y; +constexpr size_t NUM_WARPS_IN_BLOCK = NUM_WARPS_X_IN_BLOCK * NUM_WARPS_Y_IN_BLOCK; + +constexpr size_t NUM_THREADS_X_IN_WARP = WARP_TILE_DIM_X / THREAD_TILE_DIM_X; +constexpr size_t NUM_THREADS_Y_IN_WARP = kThreadsPerWarp / NUM_THREADS_X_IN_WARP; + +#define MIN(a, b) (a < b ? 
a : b) + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + block_scaled_cast_transpose_kernel(const IType* const input, OType* const output_c, + OType* const output_t, CType* const tile_scales_inv_c, + CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, + const size_t scale_stride_y, const size_t scale_t_stride_x, + const size_t scale_t_stride_y, const float epsilon, + const __grid_constant__ CUtensorMap tensor_map_output_t) { + using IVec = Vec; + using OVecCast = Vec; + using OVecTrans = Vec; + + // shared mem for amax reduction in entire block, each warp produces one amax, there are + // NUM_WARPS_IN_BLOCK amax to reduce + __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; + + IVec thrd_tile_input[THREAD_TILE_DIM_Y]; + constexpr int THREAD_TILE_DIM_X_ = kReturnTranspose ? THREAD_TILE_DIM_X : 1; + OVecTrans thrd_tile_out_trans[THREAD_TILE_DIM_X_]; + + const int tid_in_warp = threadIdx.x % kThreadsPerWarp; + const int tid_in_warp_x = tid_in_warp % NUM_THREADS_X_IN_WARP; + const int tid_in_warp_y = tid_in_warp / NUM_THREADS_X_IN_WARP; + const int warp_id_in_block = threadIdx.x / kThreadsPerWarp; + const int warp_id_in_block_x = warp_id_in_block % NUM_WARPS_X_IN_BLOCK; + const int warp_id_in_block_y = warp_id_in_block / NUM_WARPS_X_IN_BLOCK; + + // This is ONLY true if the input is a full tile + const int tile_id_x = blockIdx.x; + const int tile_id_y = blockIdx.y; + + const size_t block_tile_start_idx = + tile_id_y * BLOCK_TILE_DIM * row_length + tile_id_x * BLOCK_TILE_DIM; + const size_t warp_tile_start_idx = + block_tile_start_idx + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP * row_length + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP; + const size_t thread_tile_start_idx = warp_tile_start_idx + + tid_in_warp_y * THREAD_TILE_DIM_Y * row_length + + tid_in_warp_x * THREAD_TILE_DIM_X; + + CType warp_tile_amax; + CType block_tile_amax; + CType block_tile_scale; + CType amax = 0; + +// Step 1: Load a block tile of input data into thread tiles on registers +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + thrd_tile_input[i].load_from(input + thread_tile_start_idx + i * row_length); + } + + // Step 2: calculate block tile amax and scale + // Calculate thread_tile amax + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(static_cast(thrd_tile_input[i].data.elt[j]))); + } + } + // Reduce amax in the warp (32x32 tile) + warp_tile_amax = warp_reduce_max(amax); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); + + // reduce warp_tile_amax across multiple warps in a thread block using shared mem + if (tid_in_warp == 0) { + block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK + warp_id_in_block_x] = + warp_tile_amax; + } + __syncthreads(); + // only 8 elements needs reduction, if using reduction tree, multiple _syncthreads will be needed, + // instead we just let thread 0 do the job + if (threadIdx.x == 0) { + CType blk_amax = block_tile_amax_shared[0]; +#pragma unroll + for (int idx = 1; idx < NUM_WARPS_IN_BLOCK; idx++) { + blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]); + } + block_tile_amax_shared[0] = blk_amax; + } + __syncthreads(); + block_tile_amax = block_tile_amax_shared[0]; + + block_tile_scale = 
ComputeScale(block_tile_amax, epsilon); + + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + const CType scale_inv = 1.0f / block_tile_scale; + + size_t row_idx = tile_id_y; + size_t col_idx = tile_id_x; + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + + if constexpr (kReturnTranspose) { + row_idx = tile_id_x; + col_idx = tile_id_y; + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + + // Step 3: Store cast output, Step 4: do transpose within thread tile + OVecCast tmp_output_c; + + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + // Step 3: Store cast output + CType scale_data = block_tile_scale; + + OType scaled_elt = + static_cast(static_cast(thrd_tile_input[i].data.elt[j]) * scale_data); + tmp_output_c.data.elt[j] = scaled_elt; + // Step 4: do transpose within thread tile + if constexpr (kReturnTranspose) { + thrd_tile_out_trans[j].data.elt[i] = scaled_elt; + } + } + tmp_output_c.store_to(output_c + thread_tile_start_idx + i * row_length); + } + + // Step 4: store transpose into shared memory + if constexpr (kReturnTranspose) { +#ifdef TMA_HW_SUPPORTED + __shared__ alignas(128) + OVecTrans block_tile_trans_shared[SHARED_BLOCK_TILE_DIM_Y][SHARED_BLOCK_TILE_DIM_X_BANKS]; + OType(*block_tile_trans_shared_otype_ptr)[BLOCK_TILE_DIM] = + reinterpret_cast(block_tile_trans_shared); + +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_X; i++) { + auto warp_id_in_block_x_ = warp_id_in_block_y; + auto warp_id_in_block_y_ = warp_id_in_block_x; + int row_idx = warp_id_in_block_y_ * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP + + tid_in_warp_x * THREAD_TILE_DIM_X + i; + int col_idx = + warp_id_in_block_x_ * (NUM_BANKS_Y_IN_WARP / NUM_BANKS_PER_SHARED_ELEM) + tid_in_warp_y; + block_tile_trans_shared[row_idx][col_idx] = thrd_tile_out_trans[i]; + } + + // Wait for shared memory writes to be visible to TMA engine. + cde::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Step 5: store transpose output + // Initiate TMA transfer to copy shared memory to global memory + if (threadIdx.x == 0) { + cde::cp_async_bulk_tensor_2d_shared_to_global( + &tensor_map_output_t, tile_id_y * BLOCK_TILE_DIM, tile_id_x * BLOCK_TILE_DIM, + block_tile_trans_shared_otype_ptr); + // Wait for TMA transfer to have finished reading shared memory. + // Create a "bulk async-group" out of the previous bulk copy operation. + cde::cp_async_bulk_commit_group(); + // Wait for the group to have completed reading from shared memory. 
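+      // The template argument 0 means: wait until no committed bulk-copy groups
+      // are still reading from shared memory, so the tile buffer can be reused.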
+ cde::cp_async_bulk_wait_group_read<0>(); + } +#else + // Step 4 Alternative (when TMA is not available, skip writing to shared memory) + const size_t block_tile_t_start_idx = + tile_id_x * BLOCK_TILE_DIM * num_rows + tile_id_y * BLOCK_TILE_DIM; + const size_t warp_tile_t_start_idx = + block_tile_t_start_idx + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP; + const size_t thread_tile_t_start_idx = warp_tile_t_start_idx + + tid_in_warp_x * THREAD_TILE_DIM_X * num_rows + + tid_in_warp_y * THREAD_TILE_DIM_Y; +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_X; i++) { + thrd_tile_out_trans[i].store_to(output_t + thread_tile_t_start_idx + i * num_rows); + } +#endif + } +} + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose_kernel_notaligned( + const IType* const input, OType* const output_c, OType* const output_t, + CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, + const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon) { + using IVec = Vec; + using OVecCast = Vec; + using OVecTrans = Vec; + + // shared mem for amax reduction in entire block, each warp produces one amax, there are + // NUM_WARPS_IN_BLOCK amax to reduce + __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; + + IVec thrd_tile_input[THREAD_TILE_DIM_Y]; + constexpr int THREAD_TILE_DIM_X_ = kReturnTranspose ? THREAD_TILE_DIM_X : 1; + OVecTrans thrd_tile_out_trans[THREAD_TILE_DIM_X_]; + + const int tid_in_warp = threadIdx.x % kThreadsPerWarp; + const int tid_in_warp_x = tid_in_warp % NUM_THREADS_X_IN_WARP; + const int tid_in_warp_y = tid_in_warp / NUM_THREADS_X_IN_WARP; + const int warp_id_in_block = threadIdx.x / kThreadsPerWarp; + const int warp_id_in_block_x = warp_id_in_block % NUM_WARPS_X_IN_BLOCK; + const int warp_id_in_block_y = warp_id_in_block / NUM_WARPS_X_IN_BLOCK; + + const int tile_id_x = blockIdx.x; + const int tile_id_y = blockIdx.y; + + const size_t block_tile_start_row_idx = tile_id_y * BLOCK_TILE_DIM; + const size_t block_tile_start_col_idx = tile_id_x * BLOCK_TILE_DIM; + const size_t block_tile_start_idx = + block_tile_start_row_idx * row_length + block_tile_start_col_idx; + const size_t warp_tile_start_idx = + block_tile_start_idx + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP * row_length + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP; + const size_t thread_tile_start_idx = warp_tile_start_idx + + tid_in_warp_y * THREAD_TILE_DIM_Y * row_length + + tid_in_warp_x * THREAD_TILE_DIM_X; + + // handle non-full tile + // check for three cases: full thread tile, nonfull thread tile, empty thread tile + // for empty thread tile, directly write zero to the transposed shared mem buffer + // for nonfull thread tile, fill zero to thread tile and act as if it's full + const size_t thread_tile_start_row_idx = + tile_id_y * BLOCK_TILE_DIM + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP + + tid_in_warp_y * THREAD_TILE_DIM_Y; + const size_t thread_tile_start_col_idx = + tile_id_x * BLOCK_TILE_DIM + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP + + tid_in_warp_x * THREAD_TILE_DIM_X; + + const size_t thread_tile_end_row_idx = thread_tile_start_row_idx + THREAD_TILE_DIM_Y - 1; + const size_t thread_tile_end_col_idx = thread_tile_start_col_idx + THREAD_TILE_DIM_X - 
1; + + bool full_thrd_tile = + (thread_tile_end_row_idx < num_rows) && (thread_tile_end_col_idx < row_length); + bool empty_thrd_tile = + (thread_tile_start_row_idx >= num_rows) || (thread_tile_start_col_idx >= row_length); + bool nonfull_thrd_tile = (!full_thrd_tile) && (!empty_thrd_tile); + + const size_t thread_tile_ncols = + MIN(THREAD_TILE_DIM_X, + (MIN(thread_tile_end_col_idx, row_length - 1) - thread_tile_start_col_idx + 1)); + const size_t thread_tile_nrows = + MIN(THREAD_TILE_DIM_Y, + (MIN(thread_tile_end_row_idx, num_rows - 1) - thread_tile_start_row_idx + 1)); + + CType warp_tile_amax; + CType block_tile_amax; + CType block_tile_scale; + CType amax = 0; + + if (!empty_thrd_tile) { + // Step 1: Load a block tile of input data into thread tiles on registers + // Edge case: nonfull thread tile case, will use the partial load function here + if (nonfull_thrd_tile) { +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + if (i >= thread_tile_nrows) { + thrd_tile_input[i].clear(); + } else { + thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0, + thread_tile_ncols); + } + } + } else { +#pragma unroll + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + thrd_tile_input[i].load_from_elts(input + thread_tile_start_idx + i * row_length, 0, + THREAD_TILE_DIM_X); + } + } + + // Step 2: calculate block tile amax and scale + // Calculate thread_tile amax + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(static_cast(thrd_tile_input[i].data.elt[j]))); + } + } + } + // Reduce amax in the warp (32x32 tile) + warp_tile_amax = warp_reduce_max(amax); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + warp_tile_amax = __shfl_sync(0xFFFFFFFF, warp_tile_amax, lane_zero); + + // reduce warp_tile_amax across multiple warps in a thread block using shared mem + if (tid_in_warp == 0) { + block_tile_amax_shared[warp_id_in_block_y * NUM_WARPS_X_IN_BLOCK + warp_id_in_block_x] = + warp_tile_amax; + } + __syncthreads(); + // only 8 elements needs reduction, if using reduction tree, multiple _syncthreads will be needed, + // instead we just let thread 0 do the job + if (threadIdx.x == 0) { + CType blk_amax = block_tile_amax_shared[0]; +#pragma unroll + for (int idx = 1; idx < NUM_WARPS_IN_BLOCK; idx++) { + blk_amax = fmaxf(blk_amax, block_tile_amax_shared[idx]); + } + block_tile_amax_shared[0] = blk_amax; + } + __syncthreads(); + block_tile_amax = block_tile_amax_shared[0]; + + block_tile_scale = ComputeScale(block_tile_amax, epsilon); + + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + const CType scale_inv = 1.0f / block_tile_scale; + + size_t row_idx = tile_id_y; + size_t col_idx = tile_id_x; + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + + if constexpr (kReturnTranspose) { + row_idx = tile_id_x; + col_idx = tile_id_y; + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + + // Step 3: Store cast output, Step 4: do transpose within thread tile + // Edge case: in the non-full tile case, there are three subcases + // for full thread tile, it's the same thing here + // for nonfull thread tile, pay attention when saving tmp_output_c to global + // memory, cannot vec store_to, but need to elt store to for empty tile, + // it should not enter this step, skip to Step 4 + + // set thrd_tile_out_trans to all zero + 
if constexpr (kReturnTranspose) { +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + thrd_tile_out_trans[j].clear(); + } + } + + if (!empty_thrd_tile) { + OVecCast tmp_output_c; + for (int i = 0; i < THREAD_TILE_DIM_Y; i++) { + if (i >= thread_tile_nrows) { + continue; + } +#pragma unroll + for (int j = 0; j < THREAD_TILE_DIM_X; j++) { + // Step 3: Store cast output + CType scale_data = block_tile_scale; + + OType scaled_elt = + static_cast(static_cast(thrd_tile_input[i].data.elt[j]) * scale_data); + tmp_output_c.data.elt[j] = scaled_elt; + // Step 4: do transpose within thread tile + if constexpr (kReturnTranspose) { + thrd_tile_out_trans[j].data.elt[i] = scaled_elt; + } + } + tmp_output_c.store_to_elts(output_c + thread_tile_start_idx + i * row_length, 0, + thread_tile_ncols); + } + + if constexpr (kReturnTranspose) { + const size_t block_tile_t_start_idx = + tile_id_x * BLOCK_TILE_DIM * num_rows + tile_id_y * BLOCK_TILE_DIM; + const size_t warp_tile_t_start_idx = + block_tile_t_start_idx + + warp_id_in_block_x * THREAD_TILE_DIM_X * NUM_THREADS_X_IN_WARP * num_rows + + warp_id_in_block_y * THREAD_TILE_DIM_Y * NUM_THREADS_Y_IN_WARP; + const size_t thread_tile_t_start_idx = warp_tile_t_start_idx + + tid_in_warp_x * THREAD_TILE_DIM_X * num_rows + + tid_in_warp_y * THREAD_TILE_DIM_Y; +#pragma unroll + for (int i = 0; i < thread_tile_ncols; i++) { + thrd_tile_out_trans[i].store_to_elts(output_t + thread_tile_t_start_idx + i * num_rows, 0, + thread_tile_nrows); + } + } + } +} + +PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() { + void* driver_ptr = nullptr; + cudaDriverEntryPointQueryResult driver_status; + NVTE_CHECK_CUDA(cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, cudaEnableDefault, + &driver_status)); + return reinterpret_cast(driver_ptr); +} + +template +CUtensorMap get_tensor_map(SimpleTensor& tensor, size_t global_dim_x, size_t global_dim_y) { + // example-begin create-tensor-map + CUtensorMap tensor_map_output_trans{}; + // rank is the number of dimensions of the array. + constexpr uint32_t rank = 2; + uint64_t size[rank] = {global_dim_x, global_dim_y}; // x, y + // The stride is the number of bytes to traverse from the first element of one row to the next. + // It must be a multiple of 16. + uint64_t stride[rank - 1] = {global_dim_x * sizeof(OutputType)}; + // The box_size is the size of the shared memory buffer that is used as the + // destination of a TMA transfer. + uint32_t box_size[rank] = {BLOCK_TILE_DIM, BLOCK_TILE_DIM}; + // The distance between elements in units of sizeof(element). A stride of 2 + // can be used to load only the real component of a complex-valued tensor, for instance. + uint32_t elem_stride[rank] = {1, 1}; + + // Get a function pointer to the cuTensorMapEncodeTiled driver API. + auto cuTensorMapEncodeTiled = get_cuTensorMapEncodeTiled(); + CUtensorMapDataType dataType; + + if constexpr (std::is_same_v || + std::is_same_v) { + dataType = CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; + } else { + NVTE_CHECK(false, "Invalid Output type (must be FP8)."); + } + + // Create the tensor descriptor. 
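+  // The descriptor describes the full transposed output as a 2D byte array with
+  // 128x128 (BLOCK_TILE_DIM) boxes, so each TMA transfer moves exactly one
+  // shared-memory tile.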
+ CUresult res = cuTensorMapEncodeTiled( + &tensor_map_output_trans, // CUtensorMap *tensorMap, + dataType, + rank, // cuuint32_t tensorRank, + reinterpret_cast(tensor.dptr), // void *globalAddress, + size, // const cuuint64_t *globalDim, + stride, // const cuuint64_t *globalStrides, + box_size, // const cuuint32_t *boxDim, + elem_stride, // const cuuint32_t *elementStrides, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + // Swizzling can be used to avoid shared memory bank conflicts. + CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE, + // Any element that is outside of bounds will be set to zero by the TMA transfer. + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + + return tensor_map_output_trans; +} + +} // namespace +} // namespace transformer_engine + +namespace transformer_engine::detail { + +void nvte_quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, + SimpleTensor& scale_inv_t, SimpleTensor& output, + SimpleTensor& output_t, const float epsilon, + const bool return_transpose, const bool pow_2_scale, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_transpose_square_blockwise); + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; + size_t num_rows = 1; + for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) { + num_rows *= input.shape.at(i); + } + + NVTE_CHECK(scale_inv.shape.size() == 2, "scale_inv must have 2 dimensions."); + + size_t scale_k = scale_inv.shape[1]; + + const size_t scale_stride_x = 1; + const size_t scale_stride_y = scale_k; + + size_t scale_t_stride_x = 0; + size_t scale_t_stride_y = 0; + + if (return_transpose) { + NVTE_CHECK(output_t.shape.size() == input.shape.size(), + "output_t must have same number of dimensions as input."); + if (output_t.shape.size() > 0) { + NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t."); + for (size_t i = 1; i < output_t.shape.size(); ++i) { + NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t"); + } + } + NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same type."); + + NVTE_CHECK(scale_inv_t.shape.size() == 2, "scale_inv_t must have 2 dimensions."); + + scale_t_stride_x = 1; + scale_t_stride_y = scale_inv_t.shape[1]; + } + + const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM); + const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype, InputType, + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output.dtype, OutputType, + + dim3 grid(num_blocks_x, num_blocks_y, 1); + const bool full_tile = row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transpose, kReturnTranspose, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + pow_2_scale, kPow2Scale, + + if (full_tile) { + CUtensorMap tensor_map_output_trans; + if constexpr (kReturnTranspose) { + tensor_map_output_trans = + get_tensor_map(output_t, num_rows, row_length); + } + block_scaled_cast_transpose_kernel + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, + epsilon, 
tensor_map_output_trans); + } else { + block_scaled_cast_transpose_kernel_notaligned + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, + epsilon); + } // full-tile + + ) // kPow2Scale + ) // kReturnTranspose + ) // OutputType + ) // InputType + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu new file mode 100644 index 0000000000..d9676504ed --- /dev/null +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -0,0 +1,528 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/utils.cuh" +#include "compute_scale.cuh" + +namespace transformer_engine { +namespace { + +// clang-format off +/* + +Step 1: Load input to shared memory +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 8 times +* What each thread does in each loop: + * 8 elements are read from the input at a time + * 2 elements are written to the shared memory at a time, for a total of 4 times ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 1 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 7 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 8 times | +| ... | +| ... | +| ... | +| ... 
| ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 2: Cast and store to output_c +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 4 times +* What each thread does in each loop: + * 2 elements are read from the shared memory at a time, for a total of 8 times + * Every 8 consecutive threads do reduction and calculate the amax of each row + * 16 elements are quantized and write to output_c at a time ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | +| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | +| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 1 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 7 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 4 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 3: Transpose, cast and store to output_t +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 2 times +* What each thread does in each loop: + * 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times + * Every 8 consecutive threads do reduction and calculate the amax of each column + * 16 elements are quantized and write to output_c at a time, for a total of 2 times ++------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+ +| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | | +| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | | +| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | | +| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... 
| Warp 7 | +| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | | +| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | | +| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | | +| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | | ++-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+ + +*/ +// clang-format on + +constexpr size_t kThreadsPerWarp = 32; + +// Hyperparameters for performance tuning +constexpr int kTileDim = 128; // Fixed to 128 beacause we are using 1x128 and 128x1 quantization +constexpr int kNVecIn = 8; // The number of elements each LDG touches +constexpr int kNVecOut = 16; // The number of elements each STG touches +constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches +constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total + +// Auto-calculated constants, do not modify directly) +static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); +static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem"); +constexpr int kSMemRow = kTileDim; +constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1; +constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem; +constexpr int kNumThreadsLoad = kTileDim / kNVecIn; +constexpr int kNumThreadsStore = kTileDim / kNVecOut; +static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); +static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); + +template +__global__ void __launch_bounds__(kThreadsPerBlock) + block_scaled_1d_cast_transpose_kernel(const IType* const input, OType* const output_c, + OType* const output_t, CType* const tile_scales_inv_c, + CType* const tile_scales_inv_t, const size_t row_length, + const size_t num_rows, const size_t scale_stride_x, + const size_t scale_stride_y, + const size_t scale_t_stride_x, + const size_t scale_t_stride_y, const float epsilon) { + using SMemVec = Vec; + using OVec = Vec; + union IVec { + Vec input_type; + Vec smem_type; + }; + + extern __shared__ char smem_base[]; + SMemVec* smem = reinterpret_cast(&smem_base[0]); + + // Step 1: Load input to shared memory + { + constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory + const size_t c_g = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = (size_t)blockIdx.y * kTileDim + r_s; // Row in global memory + const size_t stride_g = (size_t)r_stride * row_length; // Stride in global memory + const size_t num_ele = + c_g < row_length ? 
min((size_t)kNVecIn, row_length - c_g) : 0; // For not aligned case + const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + IVec input_vec; + // Step 1.1: Load from global memory (input) to registers + if constexpr (kAligned) { + input_vec.input_type.load_from(input_g); + } else { + if (r_g < num_rows) { + input_vec.input_type.load_from_elts(input_g, 0, num_ele); + } else { + input_vec.input_type.clear(); + } + } + // Step 1.2: Write to shared memory +#pragma unroll + for (int i = 0; i < kNVecIn / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i]; + } + // Step 1.3: Update input address, row index of shared memory, (and row index of global memory for not aligned case) + input_g += stride_g; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + __syncthreads(); + + // Step 2: Cast and store to output_c + { + constexpr int r_stride = + kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory + const size_t c_g = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = (size_t)blockIdx.y * kTileDim + r_s; // Row in global memory + const size_t stride_g = (size_t)r_stride * row_length; // Stride in global memory + const size_t num_ele = + c_g < row_length ? min((size_t)kNVecOut, row_length - c_g) : 0; // For not aligned case + OType* output_g = &output_c[r_g * row_length + c_g]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; + // This mask represents which threads should do the reduction together. 
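+    // Example (illustrative): kNumThreadsStore = kTileDim / kNVecOut = 128 / 16 = 8,
+    // so lanes 0-7, 8-15, 16-23 and 24-31 each reduce one row; for lane 12,
+    // src_lane is 8 and mask is 0xFF00.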
+ const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut / kNVecSMem]; + // Step 2.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } + // Step 2.2: Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j])); + } + } + // Step 2.3: Reduce amax +#pragma unroll + for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + // Step 2.4: Compute scale + CType scale = ComputeScale(amax, epsilon); + // Step 2.5: Write scale_inv + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g < num_rows); + } + if (write_scale_inv) { + CType scale_inv = 1.0 / scale; + size_t row_idx = (size_t)blockIdx.y * kTileDim + r_s; + size_t col_idx = (size_t)blockIdx.x; + if constexpr (kPermuteScale) { + size_t p_row = row_idx / kTileDim; + size_t p_col = col_idx; + size_t p_dep = row_idx % kTileDim; + size_t p_2d_stride = kTileDim * scale_stride_y; + tile_scales_inv_c[p_row * p_2d_stride + p_col * kTileDim + p_dep] = scale_inv; + } else { + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + } + } + // Step 2.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + output_vec.data.elt[i * kNVecSMem + j] = + static_cast(static_cast(smem_vec[i].data.elt[j]) * scale); + } + } + // Step 2.7: Store output_c + if constexpr (kAligned) { + output_vec.store_to(output_g); + } else { + if (r_g < num_rows) { + output_vec.store_to_elts(output_g, 0, num_ele); + } + } + // Step 2.8: Update output address, row index of shared memory (and row index of global memory for not aligned case) + output_g += stride_g; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + // Step 3: Transpose, cast and store to output_t + if constexpr (kReturnTranspose) { + constexpr int c_stride = + kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory + constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); + const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory + int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory + size_t r_g = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem; // Row in global memory + const size_t c_g = (size_t)blockIdx.y * kTileDim + r_s; // Column in global memory + const size_t stride_g = (size_t)c_stride * kNVecSMem * num_rows; // Stride in global memory + const size_t num_ele = + c_g < num_rows ? min((size_t)kNVecOut, num_rows - c_g) : 0; // For not aligned case + OType* output_g = &output_t[r_g * num_rows + c_g]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. 
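+    // In this transposed pass, each group of kNumThreadsStore lanes reduces amax
+    // along a column of the input tile (a row of output_t), producing the 1x128
+    // scale for the transposed layout.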
+ const unsigned src_lane = (threadIdx.x % kThreadsPerWarp) / kNumThreadsStore * kNumThreadsStore; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsStore) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsStore) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut]; + // Step 3.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + int r = r_s + i; + int c = c_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } +#pragma unroll + for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) { + // Step 3.2: Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx])); + } + // Step 3.3: Reduce amax +#pragma unroll + for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + // Step 3.4: Compute scale + CType scale = ComputeScale(amax, epsilon); + // Step 3.5: Write scale_inv_t + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g + smem_idx < row_length); + } + if (write_scale_inv) { + CType scale_inv = 1.0 / scale; + size_t row_idx = (size_t)blockIdx.x * kTileDim + c_s * kNVecSMem + smem_idx; + size_t col_idx = (size_t)blockIdx.y; + if constexpr (kPermuteScale) { + size_t p_row = row_idx / kTileDim; + size_t p_col = col_idx; + size_t p_dep = row_idx % kTileDim; + size_t p_2d_stride = kTileDim * scale_t_stride_y; + tile_scales_inv_t[p_row * p_2d_stride + p_col * kTileDim + p_dep] = scale_inv; + } else { + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + // Step 3.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + output_vec.data.elt[i] = + static_cast(static_cast(smem_vec[i].data.elt[smem_idx]) * scale); + } + // Step 3.7: Store output_t + if constexpr (kAligned) { + output_vec.store_to(output_g + smem_idx * num_rows); + } else { + if (r_g + smem_idx < row_length) { + output_vec.store_to_elts(output_g + smem_idx * num_rows, 0, num_ele); + } + } + } + // Step 3.8: Update output address, column index of shared memory (and row index of global memory for not aligned case) + output_g += stride_g; + c_s += c_stride; + if constexpr (!kAligned) { + r_g += c_stride * kNVecSMem; + } + } + } +} + +} // namespace +} // namespace transformer_engine + +namespace transformer_engine::detail { + +void nvte_quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor& scale_inv, + SimpleTensor& scale_inv_t, SimpleTensor& output, + SimpleTensor& output_t, const float epsilon, + const bool return_transpose, const bool pow2_scale, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_transpose_vector_blockwise); + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); + + const size_t row_length = input.shape.size() > 0 ? 
input.shape.at(input.shape.size() - 1) : 1u; + size_t num_elements = row_length; + size_t num_rows = 1; + for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) { + num_rows *= input.shape.at(i); + num_elements *= input.shape.at(i); + } + + // Early return if the input tensor is empty + if (num_elements == 0) { + return; + } + + // Options for scale layout of cuBLAS GEMM kernel. + constexpr bool kPermuteScale = false; + bool permute_scale = false; + bool transpose_scales = true; + + NVTE_CHECK(input.shape.size() == output.shape.size(), + "Input and output must have the same shape."); + NVTE_CHECK((!transpose_scales || !permute_scale), + "Permute scale and transpose scales are mutually exclusive flags."); + + size_t scale_stride_x = 0; + size_t scale_stride_y = 0; + if (permute_scale) { + NVTE_CHECK(scale_inv.shape.size() == 3, "scale_inv must have 3 dimensions."); + size_t scale_k = scale_inv.shape[1]; + NVTE_CHECK(scale_inv.shape[2] == kTileDim, "Scale inner dimension must be kTileDim."); + scale_stride_x = 1; + scale_stride_y = scale_k; + } else { + NVTE_CHECK(scale_inv.shape.size() == 2, "Scale dimension must be 2 when not permuting scale."); + size_t scale_k = scale_inv.shape[1]; + scale_stride_x = 1; + scale_stride_y = scale_k; + if (transpose_scales) { + std::swap(scale_stride_x, scale_stride_y); + } + } + + size_t scale_t_stride_x = 0; + size_t scale_t_stride_y = 0; + + if (return_transpose) { + NVTE_CHECK(output_t.shape.size() == input.shape.size(), + "output_t must have same number of dimensions as input."); + if (output_t.shape.size() > 0) { + NVTE_CHECK(output_t.shape[0] == row_length, "Wrong dimension 0 of output_t."); + for (size_t i = 1; i < output_t.shape.size(); ++i) { + NVTE_CHECK(output_t.shape.at(i) == input.shape.at(i - 1), "Wrong dimension in output_t"); + } + } + + NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype."); + + if (permute_scale) { + NVTE_CHECK(scale_inv_t.shape.size() == 3, "Scale_t dimension must be 3."); + scale_t_stride_x = 1; + scale_t_stride_y = scale_inv_t.shape[1]; + NVTE_CHECK(scale_inv_t.shape[2] == kTileDim, "Scale_t inner dimension must be kTileDim."); + } else { + NVTE_CHECK(scale_inv_t.shape.size() == 2, + "Scale_t dimension must be 2 when not permuting scale."); + scale_t_stride_x = 1; + scale_t_stride_y = scale_inv_t.shape[1]; + if (transpose_scales) { + std::swap(scale_t_stride_x, scale_t_stride_y); + } + } + } + + const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); + const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype, InputType, + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output.dtype, OutputType, + + dim3 grid(num_blocks_x, num_blocks_y, 1); + + const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transpose, kReturnTranspose, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + pow2_scale, kPow2Scale, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + full_tile, kAligned, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + // shared memory must be requested up + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + &block_scaled_1d_cast_transpose_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size."); + } block_scaled_1d_cast_transpose_kernel + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + 
reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, + epsilon);) // kAligned + ) // kPow2Scale + ) // kReturnTranspose + ) // OutputType + ) // InputType + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace transformer_engine::detail diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index d1ede8d98d..229862da88 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1245,6 +1245,27 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe workspace_tensor, stream); break; } + case NVTE_BLOCK_SCALING: { + // FIXME(kwyss): Currently ignoring IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters. + if (output_tensor->block_scaling_dim == 2) { + nvte_quantize_transpose_square_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, + /*epsilon=*/output_tensor->amax_epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), + output_tensor->force_pow_2_scales, stream); + } else if (output_tensor->block_scaling_dim == 1) { + nvte_quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, + /*epsilon=*/output_tensor->amax_epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), + output_tensor->force_pow_2_scales, stream); + } else { + NVTE_ERROR("Not supported block scaling dim."); + } + break; + } default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index e529289640..6e798ca748 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -349,6 +349,7 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); } } else { + // FIXME(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); } } diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index ff475caf21..a83a508aed 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -24,6 +24,22 @@ torch.bfloat16: tex.DType.kBFloat16, } +""" +This is a map: int -> torch.dtype +Used for resolving cuda extension types to torch. 
+Has one to one mapping with enum in +transformer_engine.h +""" +TE_DType_To_Torch = { + tex.DType.kByte: torch.uint8, + tex.DType.kFloat8E4M3: torch.float8_e4m3fn, + tex.DType.kFloat8E5M2: torch.float8_e5m2, + tex.DType.kInt32: torch.int32, + tex.DType.kFloat32: torch.float32, + tex.DType.kFloat16: torch.half, + tex.DType.kBFloat16: torch.bfloat16, +} + AttnMaskTypes = ( "no_mask", "padding", diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 40245cf2d9..5e51a71ca5 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -136,6 +136,36 @@ class Float8Quantizer : public Quantizer { std::optional rowwise_data = std::nullopt) const override; }; +class Float8BlockQuantizer : public Quantizer { + public: + // Which float8 type is used for q data. + DType dtype; + + private: + // Options about how to quantize the tensor + // Quantization scales are rounded down to powers of 2. + bool force_pow_2_scales = false; + // Amax within quantization tile has a floor of epsilon. + float amax_epsilon = 0.0; + int block_scaling_dim = 2; + + public: + // Initializes from a python handle to a Float8BlockQuantizer + explicit Float8BlockQuantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_BLOCK_SCALING; } + + // Gets rowwise and columnwise_data from tensor and sets them on wrapper + void set_quantization_params(TensorWrapper* tensor) const override; + + // Create a python Float8BlockQuantized tensor and C++ wrapper + // for the tensor. Should set quantized data, scales for rowwise + // and optionally columnwise usage. + std::pair create_tensor( + const std::vector& shape, DType dtype, + std::optional rowwise_data = std::nullopt) const override; +}; + class MXFP8Quantizer : public Quantizer { public: DType dtype; diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 442837d767..27fcb524fa 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -27,6 +27,9 @@ PyTypeObject *Float8QuantizerClass = nullptr; PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove PyTypeObject *MXFP8TensorBasePythonClass = nullptr; PyTypeObject *MXFP8QuantizerClass = nullptr; +PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; +PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr; +PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; void init_float8_extension() { if (Float8TensorPythonClass) return; @@ -58,9 +61,31 @@ void init_mxfp8_extension() { "Internal error: could not initialize pyTorch MXFP8 extension."); } +void init_float8blockwise_extension() { + if (Float8BlockwiseQTensorBasePythonClass) return; + auto fp8_module = + py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor"); + auto fp8_base_module = py::module_::import( + "transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base"); + Float8BlockwiseQuantizerClass = reinterpret_cast( + PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer")); + Float8BlockwiseQTensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase")); + Float8BlockwiseQTensorPythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor")); + + NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr, + "Internal error: could not initialize 
pyTorch float8blockwise extension."); + NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr, + "Internal error: could not initialize pyTorch float8blockwise extension."); + NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch float8blockwise extension."); +} + void init_extension() { init_float8_extension(); init_mxfp8_extension(); + init_float8blockwise_extension(); } } // namespace transformer_engine::pytorch @@ -73,6 +98,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("output") = py::none(), py::arg("noop") = py::none()); m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); + m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", diff --git a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp index effeb8cb4d..fafc9043e3 100644 --- a/transformer_engine/pytorch/csrc/extensions/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/extensions/quantizer.cpp @@ -140,6 +140,132 @@ std::pair Float8Quantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } +Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { + this->dtype = quantizer.attr("dtype").cast(); + this->force_pow_2_scales = quantizer.attr("force_pow_2_scales").cast(); + this->amax_epsilon = quantizer.attr("amax_epsilon").cast(); + this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); +} + +void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const { + // Change the rowwise and columnwise_data to the configured dtype. + // May be a switch between E5M2 and E4M3. + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); + + // Set options on TensorWrapper from quantization. + tensor->set_qopt_force_pow_2_scales(force_pow_2_scales); + tensor->set_qopt_amax_epsilon(amax_epsilon); + tensor->set_qopt_block_scaling_dim(block_scaling_dim); +} + +std::pair Float8BlockQuantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional rowwise_data) const { + using namespace pybind11::literals; + std::vector torch_shape; + size_t numel = 1; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + numel *= s; + } + + TensorWrapper tensor(NVTE_BLOCK_SCALING); + at::TensorOptions opts; + at::TensorOptions scale_opts; + at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise; + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); + + size_t k_dim = torch_shape.size() == 0 ? 
1u : torch_shape.back(); + size_t m_dim = numel / k_dim; + constexpr size_t kBlockLen = 128; + + if (rowwise_usage) { + if (rowwise_data.has_value()) { + data_rowwise = std::move(*rowwise_data); + } else { + data_rowwise = at::empty(torch_shape, opts); + } + size_t sinv0 = 0; + size_t sinv1 = 0; + if (block_scaling_dim == 2) { + sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; + sinv1 = roundup((k_dim + kBlockLen - 1) / kBlockLen, 4); + } else if (block_scaling_dim == 1) { + sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; + sinv1 = m_dim; + } else { + NVTE_CHECK(false, "Unsupported block_scaling_dim in create_tensor rowwise."); + } + scale_inv_rowwise = at::empty({sinv0, sinv1}, scale_opts); + tensor.set_rowwise_data(data_rowwise.data_ptr(), this->dtype, shape); + tensor.set_rowwise_scale_inv(scale_inv_rowwise.data_ptr(), DType::kFloat32, + std::vector{sinv0, sinv1}); + } + + if (columnwise_usage) { + std::vector torch_columnwise_shape; + std::vector columnwise_shape; + NVTE_CHECK(torch_shape.size() == shape.size(), "Shape expected to match torch shape."); + if (torch_shape.size() > 0) { + torch_columnwise_shape.reserve(torch_shape.size()); + columnwise_shape.reserve(shape.size()); + torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); + columnwise_shape.push_back(shape[shape.size() - 1]); + for (size_t i = 0; i < torch_shape.size() - 1; ++i) { + torch_columnwise_shape.push_back(torch_shape[i]); + columnwise_shape.push_back(shape[i]); + } + } + size_t sinv0 = 0; + size_t sinv1 = 0; + if (block_scaling_dim == 2) { + sinv0 = (k_dim + kBlockLen - 1) / kBlockLen; + sinv1 = roundup((m_dim + kBlockLen - 1) / kBlockLen, 4); + } else if (block_scaling_dim == 1) { + sinv0 = (m_dim + kBlockLen - 1) / kBlockLen; + sinv1 = k_dim; + } else { + NVTE_CHECK(false, "Unsupported block_scaling_dim in create_tensor columnwise."); + } + data_colwise = at::empty(torch_columnwise_shape, opts); + scale_inv_colwise = at::empty({sinv0, sinv1}, scale_opts); + + tensor.set_columnwise_data(data_colwise.data_ptr(), this->dtype, columnwise_shape); + tensor.set_columnwise_scale_inv(scale_inv_colwise.data_ptr(), DType::kFloat32, + std::vector{sinv0, sinv1}); + } + this->set_quantization_params(&tensor); + + py::object ret; + if (internal) { + py::handle Float8BlockwiseQTensorClass( + reinterpret_cast(Float8BlockwiseQTensorBasePythonClass)); + ret = Float8BlockwiseQTensorClass( + "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, + "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + } else { + py::handle Float8BlockwiseQTensorClass( + reinterpret_cast(Float8BlockwiseQTensorPythonClass)); + ret = Float8BlockwiseQTensorClass( + "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, + "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, + "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, + "quantizer"_a = this->quantizer); + } + + return {std::move(tensor), std::move(ret)}; +} + MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); } diff --git a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp index d2607e4ed0..48b0919a69 100644 --- a/transformer_engine/pytorch/csrc/extensions/type_converters.cpp +++ 
b/transformer_engine/pytorch/csrc/extensions/type_converters.cpp @@ -74,6 +74,38 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizer) return ret; } +TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantizer) { + const DType dtype = tensor.attr("_fp8_dtype").cast(); + auto ret = TensorWrapper(NVTE_BLOCK_SCALING); + + bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + + if (rowwise_usage) { + const at::Tensor &data_rowwise = tensor.attr("_rowwise_data").cast(); + const at::Tensor &scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); + void *scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); + const auto &shape = getTensorShape(data_rowwise); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, shape); + + const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); + ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); + } + + if (columnwise_usage) { + const at::Tensor &data_colwise = tensor.attr("_columnwise_data").cast(); + const at::Tensor &scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); + void *scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); + const auto &shape = getTensorShape(data_colwise); + ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); + + const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); + ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); + } + quantizer->set_quantization_params(&ret); + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 0679528b94..ba3d4bb6fc 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -24,6 +24,9 @@ extern PyTypeObject *Float8QuantizerClass; extern PyTypeObject *MXFP8TensorPythonClass; extern PyTypeObject *MXFP8TensorBasePythonClass; extern PyTypeObject *MXFP8QuantizerClass; +extern PyTypeObject *Float8BlockwiseQTensorPythonClass; +extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass; +extern PyTypeObject *Float8BlockwiseQuantizerClass; void init_extension(); @@ -45,6 +48,15 @@ inline bool IsMXFP8Tensor(PyObject *obj) { return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; } +inline bool IsFloat8BlockwiseQParams(PyObject *obj) { + return Py_TYPE(obj) == Float8BlockwiseQuantizerClass; +} + +inline bool IsFloat8BlockwiseQTensor(PyObject *obj) { + return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass || + Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass; +} + TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); template @@ -56,6 +68,9 @@ TensorWrapper NVTETensorFromMXFP8Tensor(py::handle tensor, Quantizer *quantizati std::unique_ptr CreateMXFP8Params(const py::handle params); +TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, + Quantizer *quantization_params); + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } @@ -64,7 +79,9 @@ constexpr std::array custom_types_converters = { std::make_tuple(IsFloat8Tensor, IsFloat8QParams, NVTETensorFromFloat8Tensor, CreateQuantizer), std::make_tuple(IsMXFP8Tensor, IsMXFP8QParams, NVTETensorFromMXFP8Tensor, - CreateQuantizer)}; + 
CreateQuantizer), + std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQParams, + NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer)}; } // namespace detail diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py new file mode 100644 index 0000000000..ffed102ee5 --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -0,0 +1,246 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mixin class holding data specific for Float8BlockwiseQTensor""" + +from __future__ import annotations +import math +from typing import Optional, Dict, Any, Tuple +import torch + +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ...constants import TE_DType as torch_to_transformer_engine_dtype +from ...constants import TE_DType_To_Torch + +from ..quantized_tensor import Quantizer + + +class Float8BlockwiseQTensorBase: + """Mixin class that holds data attributes of Float8BlockwiseQTensor. + + Float8BlockwiseQTensor inherits from the PyTorch tensor class and this + mixin class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. + """ + + _rowwise_data: Optional[torch.Tensor] + _columnwise_data: Optional[torch.Tensor] + _quantizer: Quantizer + _fp8_dtype: TE_DType + _rowwise_scale_inv: Optional[torch.Tensor] + _columnwise_scale_inv: Optional[torch.Tensor] + + def __new__( + cls, + *args, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: Optional[torch.Tensor], + fp8_dtype: TE_DType, + quantizer: Quantizer, + **kwargs, + ): + instance = super().__new__(cls, *args, **kwargs) + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._quantizer = quantizer + instance._fp8_dtype = fp8_dtype + instance._rowwise_scale_inv = rowwise_scale_inv + instance._columnwise_scale_inv = columnwise_scale_inv + + return instance + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale_inv": self._rowwise_scale_inv, + "columnwise_data": self._columnwise_data, + "columnwise_scale_inv": self._columnwise_scale_inv, + "fp8_dtype": self._fp8_dtype, + "quantizer": self._quantizer, + } + + def prepare_for_saving( + self, + ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: + """Prepare the tensor base for saving for backward + + FIXME(kwyss): Should this clear out data? + FIXME(kwyss): What about dq scales? 
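+
+        The returned tensors are what the caller stashes for backward and later
+        passes back to restore_from_saved() (assumed usage, mirroring the other
+        quantized tensor base classes).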
+ """ + tensors = [self._rowwise_data, self._columnwise_data] + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + return tensors[2:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._rowwise_data, self._columnwise_data + + def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor: + """Takes dequantized columnwise data and permutes to a rowwise shape""" + if columnwise_dq.dim() < 2: + return columnwise_dq + permute_dims = [x for x in range(1, columnwise_dq.dim())] + permute_dims.append(0) + return torch.permute(columnwise_dq, tuple(permute_dims)).contiguous() + + def _dequantize_vectorwise(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + block_len = 128 + + q_M, q_K = 1, 1 + if self._rowwise_data is not None: + q = self._rowwise_data + scale_inv = self._rowwise_scale_inv + transpose_output = False + if len(q.shape) >= 1: + q_K = q.shape[-1] + for i in range(len(q.shape) - 1): + q_M *= q.shape[i] + else: + assert self._columnwise_data is not None, "No data to dequantize" + q = self._columnwise_data + scale_inv = self._columnwise_scale_inv + transpose_output = True + if len(q.shape) >= 1: + q_M = q.shape[0] + for i in range(1, len(q.shape)): + q_K *= q.shape[i] + + orig_shape = q.shape + q = q.reshape(q_M, q_K) + k_tiles, m = scale_inv.shape + if q_K % block_len != 0: + k_pad_amount = (block_len - (q_K % block_len)) % block_len + q = torch.nn.functional.pad( + q, (0, k_pad_amount, 0, 0), mode="constant", value=0 + ).contiguous() + _, padded_K = q.shape + q_tiled = q.reshape(q_M, k_tiles, block_len) + dq_scale = scale_inv.transpose(-2, -1).contiguous().reshape(m, k_tiles, 1) + torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype] + result = q_tiled.view(torch_q_dtype).to(torch.float32) * dq_scale + if padded_K != q_K: + result = result.reshape(q_M, padded_K)[:, :q_K] + result = result.to(dtype) + if len(orig_shape) == 0: + result = result.reshape([]) + else: + result = result.reshape(*orig_shape).contiguous() + + if transpose_output: + return self._transpose_dq_columnwise_output(result) + return result + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8BlockwiseQTensor + """ + block_len = 128 + assert self._quantizer is not None + if self._quantizer.block_scaling_dim != 2: + assert self._quantizer.block_scaling_dim == 1 + return self._dequantize_vectorwise(dtype=dtype) + + def format_scale_as_logical_shape(q_M, q_K, scales, block_len): + # The GEMM for 2D blocks required padding in the scales. 
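+            # Illustrative example (assumed sizes): q_K == 300 with block_len == 128
+            # gives derived_scale_k_shape == 3, while the stored scales may carry an
+            # inner dim padded up to a multiple of 4, so the padding is sliced off here.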
+            derived_scale_k_shape = math.ceil(q_K / block_len)
+            scale_M, scale_K = scales.shape
+            if derived_scale_k_shape == scale_K:
+                return scales
+            else:
+                return scales[:, :derived_scale_k_shape].contiguous()
+
+        q_M, q_K = 1, 1
+        if self._rowwise_data is not None:
+            q = self._rowwise_data
+            scale_inv = self._rowwise_scale_inv
+            transpose_output = False
+            if len(q.shape) >= 1:
+                q_K = q.shape[-1]
+                for i in range(len(q.shape) - 1):
+                    q_M *= q.shape[i]
+        else:
+            assert self._columnwise_data is not None, "No data to dequantize"
+            q = self._columnwise_data
+            scale_inv = self._columnwise_scale_inv
+            transpose_output = True
+            if len(q.shape) >= 1:
+                q_M = q.shape[0]
+                for i in range(1, len(q.shape)):
+                    q_K *= q.shape[i]
+
+        orig_shape = q.shape
+        q = q.reshape(q_M, q_K)
+        formatted_scales = format_scale_as_logical_shape(q_M, q_K, scale_inv, block_len)
+        assert len(formatted_scales.shape) == 2
+        m_tiles, k_tiles = formatted_scales.shape
+        unpadded_m, unpadded_k = q_M, q_K
+        m_block_len = block_len
+        k_block_len = block_len
+        if q_M % m_block_len != 0 or q_K % k_block_len != 0:
+            m_pad_amount = (m_block_len - (q_M % m_block_len)) % m_block_len
+            k_pad_amount = (k_block_len - (q_K % k_block_len)) % k_block_len
+            q = torch.nn.functional.pad(
+                q, (0, k_pad_amount, 0, m_pad_amount), mode="constant", value=0
+            ).contiguous()
+        padded_M, padded_K = q.shape
+        q_tiled = q.reshape(m_tiles, m_block_len, k_tiles, k_block_len)
+
+        torch_q_dtype = TE_DType_To_Torch[self._fp8_dtype]
+
+        result = q_tiled.view(torch_q_dtype).to(torch.float32) * formatted_scales.view(
+            m_tiles, 1, k_tiles, 1
+        )
+        result = result.view(padded_M, padded_K).to(dtype)
+        if padded_M != unpadded_m or padded_K != unpadded_k:
+            result = result[:unpadded_m, :unpadded_k]
+        if len(orig_shape) == 0:
+            result = result.reshape([])
+        else:
+            result = result.reshape(*orig_shape).contiguous()
+        if transpose_output:
+            return self._transpose_dq_columnwise_output(result)
+        return result
+
+    def size(self, *args, **kwargs):
+        # pylint: disable=missing-function-docstring
+        if self._rowwise_data is not None:
+            return self._rowwise_data.size(*args, **kwargs)
+        else:
+            dims = list(self._columnwise_data.size(*args, **kwargs))
+            reordered = []
+            for i in range(1, len(dims)):
+                reordered.append(dims[i])
+            reordered.append(dims[0])
+            return torch.Size(reordered)
+
+    def __repr__(self):
+        if self._rowwise_data is not None:
+            data = self.dequantize()
+            descriptor = "rowwise"
+            scale_inv = self._rowwise_scale_inv
+        else:
+            data = self.dequantize()
+            descriptor = "columnwise"
+            scale_inv = self._columnwise_scale_inv
+        return (
+            "Float8BlockwiseQTensorBase("
+            f"fp8_dtype={self._fp8_dtype}, "
+            f"{descriptor}_scaled_data={data}, "
+            f"{descriptor}_scale_inv={scale_inv}, "
+            ")"
+        )
diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py
index d78bd55d9a..45de7e621e 100644
--- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py
+++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py
@@ -96,7 +96,7 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorB
         """Prepare the tensor base for saving for backward
 
         After calling this, the tensor instance does not hold any
-        data.
+        data. Yes it does?
TODO """ tensors = [self._rowwise_data, self._columnwise_data] diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py new file mode 100644 index 0000000000..8090ac20b8 --- /dev/null +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -0,0 +1,539 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor class with FP8 data quantized with NxN tiles""" +from __future__ import annotations +from typing import Optional, Tuple, Iterable +import warnings + +import math +import torch +import transformer_engine_torch as tex + +from transformer_engine_torch import DType as TE_DType +from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from ..utils import devices_match, round_up_to_nearest_multiple + +aten = torch.ops.aten + + +class Float8BlockQuantizer(Quantizer): + """Builder class for tensors quantized with current scaling using + NxN quantization tilings to choose scale. + + This class is typically used to convert a high-precision tensor + (e.g. in FP32 or BF16) into a quantized tensor (e.g. in FP8). + + """ + + dtype: TE_DType + block_len: int + amax_epsilon: float + force_pow_2_scales: bool + block_scaling_dim: int + + def __init__( + self, + fp8_dtype: TE_DType, + *, + rowwise: bool, + columnwise: bool, + amax_epsilon: float = 0.0, + force_pow_2_scales: bool = False, + block_scaling_dim: int = 2, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + assert rowwise + self.dtype = fp8_dtype + self.block_len = 128 + self.force_pow_2_scales = force_pow_2_scales + self.amax_epsilon = amax_epsilon + self.block_scaling_dim = block_scaling_dim + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + assert isinstance( + dst, Float8BlockwiseQTensor + ), f"Cannot store quantized blockwise tensor in {type(dst)} type." + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + + dst._fp8_dtype = self.dtype + return dst + + def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: + # cuBLAS kernel format (for NxN by NxN and 1xN by NxN GEMMs) + # The scales for 2D block quantized tensors must have scales padded + # to multiples of 4 on the inner dimension. TODO: Verify whether outer + # dimension also to be padded for either GEMM. 
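+        # Worked example (assumed shapes): for shape == (256, 384) with block_len == 128
+        # and block_scaling_dim == 2, the rowwise scales are
+        # (ceil(256 / 128), roundup(ceil(384 / 128), 4)) == (2, 4); the columnwise
+        # scales swap the tile counts before padding, giving (3, 4).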
+ if self.block_scaling_dim == 2: + logical_scale_shape = [1, 1] + for i in range(len(shape) - 1): + logical_scale_shape[-2] *= shape[i] + if len(shape) > 0: + logical_scale_shape[-1] = math.ceil(shape[-1] / self.block_len) + logical_scale_shape[-2] = math.ceil(logical_scale_shape[-2] / self.block_len) + if columnwise: + tmp = logical_scale_shape[-1] + logical_scale_shape[-1] = logical_scale_shape[-2] + logical_scale_shape[-2] = tmp + logical_scale_shape[-1] = round_up_to_nearest_multiple(logical_scale_shape[-1], 4) + return tuple(logical_scale_shape) + else: + assert self.block_scaling_dim == 1, "Only 1D or 2D blocks supported" + + logical_scale_shape = [1, 1] + for i in range(len(shape) - 1): + logical_scale_shape[-1] *= shape[i] + if len(shape) > 0: + logical_scale_shape[-2] = shape[-1] + if not columnwise: + logical_scale_shape[-2] = math.ceil(logical_scale_shape[-2] / self.block_len) + return tuple(logical_scale_shape) + else: + logical_scale_shape[-1] = math.ceil(logical_scale_shape[-1] / self.block_len) + return (logical_scale_shape[1], logical_scale_shape[0]) + + def get_columnwise_shape(self, shape: Iterable[int]) -> Tuple[int, ...]: + if len(shape) == 0: + return tuple() + colwise_shape = [shape[-1]] + for i in range(len(shape) - 1): + colwise_shape.append(shape[i]) + return tuple(colwise_shape) + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> Float8BlockwiseQTensor: + """Construct quantized tensor with uninitialized data""" + if device is None: + device = torch.device("cuda") + + # Allocate FP8 data + data = torch.empty(shape, dtype=torch.uint8, device=device) + scale_shape = self.get_scale_shape(shape, columnwise=False) + scale_inv = torch.empty( + scale_shape, + dtype=torch.float32, + device=device, + ) + + # Allocate FP8 data transpose if needed + columnwise_data = None + columnwise_scale_inv = None + if self.columnwise_usage: + columnwise_data = torch.empty( + self.get_columnwise_shape(shape), dtype=torch.uint8, device=device + ) + columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) + columnwise_scale_inv = torch.empty( + columnwise_scale_shape, + dtype=torch.float32, + device=device, + ) + + # Construct FP8 tensor + return Float8BlockwiseQTensor( + shape=shape, + dtype=dtype, + fp8_dtype=self.dtype, + rowwise_data=data, + rowwise_scale_inv=scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + quantizer=self, + requires_grad=requires_grad, + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + # NOTE: This interface is specific to requirements like delayed scaling + # where state from an estimator influences distribution parameters. + pass + + +class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): + """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP8. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + rowwise_data: torch.Tensor + FP8 data in a uint8 tensor matching shape of dequantized tensor. + rowwise_scale_inv: torch.Tensor + FP32 dequantization scales in GEMM format for dequantizing rowwise_data. + columnwise_data: Optional[torch.Tensor] + FP8 data in a uint8 tensor matching shape of dequantized tensor transpose. 
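+        Its dimensions are rotated so that the original last dimension comes
+        first (see Float8BlockQuantizer.get_columnwise_shape).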
+ columnwise_scale_inv: Optional[torch.Tensor] + FP32 dequantization scales in GEMM format for dequantizing columnwise_data. + + fp8_dtype: transformer_engine_torch.DType, default = kFloat8E4M3 + FP8 format. + quantizer: Quantizer - the Float8BlockQuantizer that quantized this tensor and + holds configuration about quantization and dequantization modes. + """ + + def __repr__(self, *, tensor_contents=None): + return ( + f"Float8BlockwiseQTensor(fp8_dtype={self._fp8_dtype}," + f" data={self.dequantize(dtype=self.dtype)})" + ) + + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. + + """ + assert self._quantizer is not None + return self._quantizer + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> Float8BlockwiseQTensor: + """Update FP8 data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize()) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from Float8BlockwiseQTensor + + By default the resulting tensor's dtype is the + Float8BlockwiseQTensor's pre-quantized dtype. + """ + if dtype is not None: + dequant_dtype = dtype + else: + dequant_dtype = self.dtype + return super().dequantize(dtype=dequant_dtype) + + def detach(self) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + return Float8BlockwiseQTensor.make_like(self) + + def update_usage(self, rowwise_usage=True, columnwise_usage=True): + """ + update_usage can be used to clear out one of two possible copies of the data. + """ + + assert ( + columnwise_usage or rowwise_usage + ), "Must retain some data either columnwise or rowwise" + + if columnwise_usage and rowwise_usage: + assert ( + self._rowwise_data is not None + and self._rowwise_scale_inv is not None + and self._columnwise_data is not None + and self._columnwise_scale_inv is not None + ), "Cannot update to rowwise and columnwise usage." + return + + if rowwise_usage: + assert ( + self._rowwise_data is not None and self._rowwise_scale_inv is not None + ), "Cannot update to rowwise usage." + self._columnwise_data = None + self._columnwise_scale_inv = None + return + if columnwise_usage: + assert ( + self._columnwise_data is not None and self._columnwise_scale_inv is not None + ), "Cannot update to columnwise usage." 
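+            # Keeping only the columnwise copy: drop the rowwise data and scales so
+            # their memory can be reclaimed.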
+ self._rowwise_data = None + self._rowwise_scale_inv = None + return + + return + + def clone(self) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + rowwise_data = None + if self._rowwise_data is not None: + rowwise_data = self._rowwise_data.detach().clone() + columnwise_data = None + if self._columnwise_data is not None: + columnwise_data = self._columnwise_data.detach().clone() + return _IdentityFunc.apply( + self, + { + "rowwise_data": rowwise_data, + "columnwise_data": columnwise_data, + }, + ) + + def view(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> Float8BlockwiseQTensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. + + """ + if ( + self._rowwise_data is not None + and self._rowwise_data.is_contiguous(memory_format=memory_format) + and ( + (self._columnwise_data is None) + or (self._columnwise_data.is_contiguous(memory_format=memory_format)) + ) + ): + return self + raise ValueError("Float8BlockwiseQTensor does not support different memory formats!") + + def clear(self): + """Deallocate this tensor's memory. Typically not needed and must be used carefully.""" + self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None + self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # View op + if func == aten.view.default: + tensor = args[0] + data = tensor._rowwise_data + out_data = data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + out_shape = out_data.size() + return Float8BlockwiseQTensor( + shape=out_shape, + dtype=tensor.dtype, + rowwise_data=out_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=tensor._columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + quantizer=tensor._quantizer, + requires_grad=False, + fp8_dtype=tensor._fp8_dtype, + ) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + fp8_dtype: TE_DType, + dtype: torch.dtype, + ) -> Float8BlockwiseQTensor: + """Build Float8BlockwiseQTensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. 
+ + """ + return Float8BlockwiseQTensor( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + fp8_dtype=fp8_dtype, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + dtype=dtype, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling to remove references to FP8 metadata objects""" + return ( + Float8BlockwiseQTensor._make_in_reduce_ex, + ( + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._fp8_dtype, + self.dtype, + ), + ) + + def _get_data(self) -> Float8BlockwiseQTensor: + """Get tensor data property""" + return super().data + + @torch.no_grad() + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Just takes FP8 data if setting from a Float8BlockwiseQTensor. Otherwise + casts to FP8. + + """ + + # Tensor device + new_device = tensor.device if tensor.is_cuda else self.device + + # Just copy FP8 data if other tensor is Float8BlockwiseQTensor + if isinstance(tensor, Float8BlockwiseQTensor): + if ( # pylint: disable=too-many-boolean-expressions + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.storage_offset() != tensor.storage_offset() + or self.dtype != tensor.dtype + or self.layout != tensor.layout + or not devices_match(self.device, new_device) + ): + dummy_tensor = torch.Tensor._make_wrapper_subclass( + Float8BlockwiseQTensor, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + device=new_device, + ) + # pylint: disable=unnecessary-dunder-call + super(Float8BlockwiseQTensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data + self._columnwise_data = tensor._columnwise_data + self._quantizer = tensor._quantizer + self._fp8_dtype = tensor._fp8_dtype + self._rowwise_scale_inv = tensor._rowwise_scale_inv + self._columnwise_scale_inv = tensor._columnwise_scale_inv + return + + # Quantize to FP8 + assert self._quantizer is not None, "Can't quantize without a quantizer" + self.data = self._quantizer.quantize(tensor) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) + + # Cast to FP8 when setting Float8BlockwiseQTensor.data + data = property(_get_data, _set_data) + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the Float8BlockwiseQTensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: Float8BlockwiseQTensor, + shape: Optional[list[int]] = None, + ) -> Float8BlockwiseQTensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + ctx.shape = tensor.shape + if shape is None: + return tensor + + if shape != ctx.shape: + raise NotImplementedError("View not implemented.") + else: + return tensor + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, Float8BlockwiseQTensor): + raise NotImplementedError("View bwd not implemented") + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the Float8BlockwiseQTensor using the provided shape. 
+
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        tensor: Float8BlockwiseQTensor,
+        shape: Optional[list[int]] = None,
+    ) -> Float8BlockwiseQTensor:
+        # pylint: disable=missing-function-docstring
+
+        # Return input tensor if shape is not provided
+        ctx.shape = tensor.shape
+        if shape is None:
+            return tensor
+
+        # Canonicalize shape
+        if not isinstance(shape, Iterable):
+            shape = [shape]
+        elif len(shape) == 1 and isinstance(shape[0], Iterable):
+            shape = shape[0]
+        if -1 in shape:
+            shape = list(shape)
+            d_inferred = -math.prod(ctx.shape) // math.prod(shape)
+            for i, d in enumerate(shape):
+                if d == -1:
+                    shape[i] = d_inferred
+                    break
+        if tuple(shape) != tuple(ctx.shape):
+            raise NotImplementedError("Reshape not implemented yet.")
+        return tensor
+
+    @staticmethod
+    def backward(
+        ctx,
+        grad: torch.Tensor,
+    ) -> Tuple[Optional[torch.Tensor], ...]:
+        # pylint: disable=missing-function-docstring
+
+        if isinstance(grad, Float8BlockwiseQTensor):
+            raise NotImplementedError("Reshape bwd not implemented yet.")
+        return grad.view(ctx.shape), None
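+
+
+# A minimal usage sketch (illustrative only; shapes, dtypes and values are assumed):
+#
+#     quantizer = Float8BlockQuantizer(
+#         fp8_dtype=TE_DType.kFloat8E4M3, rowwise=True, columnwise=True
+#     )
+#     x = torch.randn(256, 384, device="cuda", dtype=torch.bfloat16)
+#     x_q = quantizer.make_empty(x.shape, dtype=x.dtype, device=x.device)
+#     quantizer.update_quantized(x, x_q)
+#     x_dq = x_q.dequantize(dtype=torch.bfloat16)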